Skip to content

Commit bd38f61

Browse files
authored
Merge pull request #332 from lucienwang1009/onnx_attr_check
Fetch valid attributes from ONNX dynamically
2 parents 510b8d2 + 2ace1ca commit bd38f61

File tree

5 files changed

+100
-20
lines changed

5 files changed

+100
-20
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.coverage
1+
.coverage*
22
*.pyc
33
.idea
44
build

tf2onnx/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import unicode_literals
88

99

10-
__all__ = ["utils", "graph_matcher", "graph", "tfonnx", "shape_inference"]
10+
__all__ = ["utils", "graph_matcher", "graph", "tfonnx", "shape_inference", "schemas"]
1111

1212
from .version import version as __version__
13-
from tf2onnx import tfonnx, utils, graph, graph_matcher, shape_inference # pylint: disable=wrong-import-order
13+
from tf2onnx import tfonnx, utils, graph, graph_matcher, shape_inference, schemas # pylint: disable=wrong-import-order

tf2onnx/graph.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tf2onnx import utils, __version__
2020
from tf2onnx.utils import port_name, find_opset
2121
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
22+
from tf2onnx.schemas import get_schema
2223

2324

2425
# todo(pengwa): remove protected-access later
@@ -72,8 +73,10 @@ def attr(self):
7273
def attr_onnx(self):
7374
onnx_attrs = {}
7475
for a in self._attr.values():
75-
if a.name in utils.ONNX_VALID_ATTRIBUTES:
76-
onnx_attrs[a.name] = a
76+
schema = get_schema(self.type, self.graph.opset)
77+
if schema:
78+
if schema.has_attribute(a.name):
79+
onnx_attrs[a.name] = a
7780
return onnx_attrs
7881

7982
@property

tf2onnx/schemas.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.schema
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
from collections import defaultdict, OrderedDict
13+
from onnx import defs
14+
15+
16+
ONNX_DOMAIN = ""
17+
18+
19+
class OnnxOpSchema(object):
20+
"""Wrapper for Onnx schema."""
21+
22+
def __init__(self, name, domain, since_version, attributes):
23+
"""Create a Onnx schema
24+
Args:
25+
name (str): op name
26+
attributes (List[str]): valid attributes
27+
domain (str): default value "" means it's Onnx domain
28+
since_version (int): opset version, default is 1
29+
"""
30+
self._name = name
31+
self._domain = domain
32+
self._attributes = attributes
33+
self._since_version = since_version
34+
35+
@property
36+
def attributes(self):
37+
return self._attributes
38+
39+
@property
40+
def domain(self):
41+
return self._domain
42+
43+
@property
44+
def name(self):
45+
return self._name
46+
47+
@property
48+
def since_version(self):
49+
return self._since_version
50+
51+
@staticmethod
52+
def from_onnx_schema(onnx_schema):
53+
name = onnx_schema.name
54+
domain = onnx_schema.domain
55+
since_version = int(onnx_schema.since_version)
56+
attributes = onnx_schema.attributes
57+
return OnnxOpSchema(name, domain, since_version, attributes)
58+
59+
def has_attribute(self, attr):
60+
return attr in self.attributes
61+
62+
63+
def _register_all_schemas_with_history():
64+
"""Register all schemas with history"""
65+
onnx_schemas = defs.get_all_schemas_with_history()
66+
name_domain_version_schema_map = defaultdict(lambda: defaultdict(dict))
67+
for s in onnx_schemas:
68+
schema = OnnxOpSchema.from_onnx_schema(s)
69+
name_domain_version_schema_map[schema.name][schema.domain][schema.since_version] = schema
70+
71+
ordered_map = defaultdict(lambda: defaultdict(OrderedDict))
72+
for name, domain_version_schema_map in name_domain_version_schema_map.items():
73+
for domain, version_schema_map in domain_version_schema_map.items():
74+
ordered_map[name][domain] = OrderedDict(
75+
sorted(version_schema_map.items(), key=lambda x: -x[0])
76+
)
77+
return ordered_map
78+
79+
80+
# format is <OpName, <Domain, <SinceVersion, OpSchema>>>
81+
# SinceVersion is sorted from high to low
82+
_schemas = _register_all_schemas_with_history()
83+
84+
85+
def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
86+
"""Get schema by name within specific version."""
87+
domain_version_schema_map = _schemas[name]
88+
version_schema_map = domain_version_schema_map[domain]
89+
for version, schema in version_schema_map.items():
90+
if version <= max_inclusive_opset_version:
91+
return schema
92+
return None

tf2onnx/utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,6 @@
7676

7777
ONNX_UNKNOWN_DIMENSION = -1
7878

79-
#
80-
# attributes onnx understands. Everything else coming from tensorflow
81-
# will be ignored.
82-
#
83-
ONNX_VALID_ATTRIBUTES = {
84-
'p', 'bias', 'axes', 'pads', 'mean', 'activation_beta', 'spatial_scale', 'broadcast', 'pooled_shape', 'high',
85-
'activation_alpha', 'is_test', 'hidden_size', 'activations', 'beta', 'input_as_shape', 'drop_states', 'alpha',
86-
'momentum', 'scale', 'axis', 'dilations', 'transB', 'axis_w', 'blocksize', 'output_sequence', 'mode', 'perm',
87-
'min', 'seed', 'ends', 'paddings', 'to', 'gamma', 'width_scale', 'normalize_variance', 'group', 'ratio', 'values',
88-
'dtype', 'output_shape', 'spatial', 'split', 'input_forget', 'keepdims', 'transA', 'auto_pad', 'border', 'low',
89-
'linear_before_reset', 'height_scale', 'output_padding', 'shape', 'kernel_shape', 'epsilon', 'size', 'starts',
90-
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k', 'sample_size',
91-
'blocksize', 'epsilon', 'momentum', 'body', 'directions', 'num_scan_inputs', 'then_branch', 'else_branch'
92-
}
93-
9479
# index for internally generated names
9580
INTERNAL_NAME = 1
9681

0 commit comments

Comments
 (0)