Skip to content

Commit 75a7f24

Browse files
author
wayuanho
committed
fetch valid attributes from ONNX dynamically
1 parent 0b4ce93 commit 75a7f24

File tree

3 files changed

+55
-16
lines changed

3 files changed

+55
-16
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from tf2onnx import utils, __version__
2020
from tf2onnx.utils import port_name, find_opset
2121
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
22+
from tf2onnx.schema import ONNXSchema
2223

2324

2425
# todo(pengwa): remove protected-access later
@@ -72,7 +73,7 @@ def attr(self):
7273
def attr_onnx(self):
7374
onnx_attrs = {}
7475
for a in self._attr.values():
75-
if a.name in utils.ONNX_VALID_ATTRIBUTES:
76+
if a.name in ONNXSchema.get_attribute(self.type, self.graph.opset):
7677
onnx_attrs[a.name] = a
7778
return onnx_attrs
7879

tf2onnx/schema.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.schema
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
from collections import defaultdict, OrderedDict
13+
from onnx import defs
14+
15+
def version_mapping(schemas):
16+
"""Obtain version to schema mapping."""
17+
schemas_with_version = defaultdict(dict)
18+
for s in schemas:
19+
schemas_with_version[s.name][int(s.since_version)] = s
20+
for name, s in schemas_with_version.items():
21+
schemas_with_version[name] = OrderedDict(
22+
sorted(schemas_with_version[name].items(), key=lambda x: x[0])
23+
)
24+
return schemas_with_version
25+
26+
27+
class ONNXSchema:
28+
"""Wrapper for ONNX schema"""
29+
30+
all_schemas = defs.get_all_schemas_with_history()
31+
schemas_with_version = version_mapping(all_schemas)
32+
33+
@staticmethod
34+
def get_schema(name, version):
35+
"""Get schema by name within specific version."""
36+
if name not in ONNXSchema.schemas_with_version:
37+
return None
38+
if version < 1:
39+
return None
40+
schemas = ONNXSchema.schemas_with_version[name]
41+
versions = list(schemas.keys())
42+
for i, v in enumerate(versions):
43+
if version < v:
44+
return schemas[versions[i-1]]
45+
return schemas[versions[-1]]
46+
47+
@staticmethod
48+
def get_attribute(name, version):
49+
"""Get valid attributes by op's name and specific version"""
50+
schema = ONNXSchema.get_schema(name, version)
51+
if not schema:
52+
return {}
53+
return schema.attributes

tf2onnx/utils.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,6 @@
7676

7777
ONNX_UNKNOWN_DIMENSION = -1
7878

79-
#
80-
# attributes onnx understands. Everything else coming from tensorflow
81-
# will be ignored.
82-
#
83-
ONNX_VALID_ATTRIBUTES = {
84-
'p', 'bias', 'axes', 'pads', 'mean', 'activation_beta', 'spatial_scale', 'broadcast', 'pooled_shape', 'high',
85-
'activation_alpha', 'is_test', 'hidden_size', 'activations', 'beta', 'input_as_shape', 'drop_states', 'alpha',
86-
'momentum', 'scale', 'axis', 'dilations', 'transB', 'axis_w', 'blocksize', 'output_sequence', 'mode', 'perm',
87-
'min', 'seed', 'ends', 'paddings', 'to', 'gamma', 'width_scale', 'normalize_variance', 'group', 'ratio', 'values',
88-
'dtype', 'output_shape', 'spatial', 'split', 'input_forget', 'keepdims', 'transA', 'auto_pad', 'border', 'low',
89-
'linear_before_reset', 'height_scale', 'output_padding', 'shape', 'kernel_shape', 'epsilon', 'size', 'starts',
90-
'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k', 'sample_size',
91-
'blocksize', 'epsilon', 'momentum', 'body', 'directions', 'num_scan_inputs', 'then_branch', 'else_branch'
92-
}
93-
9479
# index for internally generated names
9580
INTERNAL_NAME = 1
9681

0 commit comments

Comments
 (0)