Skip to content

Commit 2ace1ca

Browse files
author
wayuanho
committed
resolve requests
1 parent 75a7f24 commit 2ace1ca

File tree

5 files changed

+100
-59
lines changed

5 files changed

+100
-59
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.coverage
1+
.coverage*
22
*.pyc
33
.idea
44
build

tf2onnx/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import unicode_literals
88

99

10-
__all__ = ["utils", "graph_matcher", "graph", "tfonnx", "shape_inference"]
10+
__all__ = ["utils", "graph_matcher", "graph", "tfonnx", "shape_inference", "schemas"]
1111

1212
from .version import version as __version__
13-
from tf2onnx import tfonnx, utils, graph, graph_matcher, shape_inference # pylint: disable=wrong-import-order
13+
from tf2onnx import tfonnx, utils, graph, graph_matcher, shape_inference, schemas # pylint: disable=wrong-import-order

tf2onnx/graph.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tf2onnx import utils, __version__
2020
from tf2onnx.utils import port_name, find_opset
2121
from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
22-
from tf2onnx.schema import ONNXSchema
22+
from tf2onnx.schemas import get_schema
2323

2424

2525
# todo(pengwa): remove protected-access later
@@ -73,8 +73,10 @@ def attr(self):
7373
def attr_onnx(self):
7474
onnx_attrs = {}
7575
for a in self._attr.values():
76-
if a.name in ONNXSchema.get_attribute(self.type, self.graph.opset):
77-
onnx_attrs[a.name] = a
76+
schema = get_schema(self.type, self.graph.opset)
77+
if schema:
78+
if schema.has_attribute(a.name):
79+
onnx_attrs[a.name] = a
7880
return onnx_attrs
7981

8082
@property

tf2onnx/schema.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

tf2onnx/schemas.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.schema
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
from collections import defaultdict, OrderedDict
13+
from onnx import defs
14+
15+
16+
ONNX_DOMAIN = ""
17+
18+
19+
class OnnxOpSchema(object):
20+
"""Wrapper for Onnx schema."""
21+
22+
def __init__(self, name, domain, since_version, attributes):
23+
"""Create a Onnx schema
24+
Args:
25+
name (str): op name
26+
attributes (List[str]): valid attributes
27+
domain (str): default value "" means it's Onnx domain
28+
since_version (int): opset version, default is 1
29+
"""
30+
self._name = name
31+
self._domain = domain
32+
self._attributes = attributes
33+
self._since_version = since_version
34+
35+
@property
36+
def attributes(self):
37+
return self._attributes
38+
39+
@property
40+
def domain(self):
41+
return self._domain
42+
43+
@property
44+
def name(self):
45+
return self._name
46+
47+
@property
48+
def since_version(self):
49+
return self._since_version
50+
51+
@staticmethod
52+
def from_onnx_schema(onnx_schema):
53+
name = onnx_schema.name
54+
domain = onnx_schema.domain
55+
since_version = int(onnx_schema.since_version)
56+
attributes = onnx_schema.attributes
57+
return OnnxOpSchema(name, domain, since_version, attributes)
58+
59+
def has_attribute(self, attr):
60+
return attr in self.attributes
61+
62+
63+
def _register_all_schemas_with_history():
64+
"""Register all schemas with history"""
65+
onnx_schemas = defs.get_all_schemas_with_history()
66+
name_domain_version_schema_map = defaultdict(lambda: defaultdict(dict))
67+
for s in onnx_schemas:
68+
schema = OnnxOpSchema.from_onnx_schema(s)
69+
name_domain_version_schema_map[schema.name][schema.domain][schema.since_version] = schema
70+
71+
ordered_map = defaultdict(lambda: defaultdict(OrderedDict))
72+
for name, domain_version_schema_map in name_domain_version_schema_map.items():
73+
for domain, version_schema_map in domain_version_schema_map.items():
74+
ordered_map[name][domain] = OrderedDict(
75+
sorted(version_schema_map.items(), key=lambda x: -x[0])
76+
)
77+
return ordered_map
78+
79+
80+
# format is <OpName, <Domain, <SinceVersion, OpSchema>>>
81+
# SinceVersion is sorted from high to low
82+
_schemas = _register_all_schemas_with_history()
83+
84+
85+
def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
86+
"""Get schema by name within specific version."""
87+
domain_version_schema_map = _schemas[name]
88+
version_schema_map = domain_version_schema_map[domain]
89+
for version, schema in version_schema_map.items():
90+
if version <= max_inclusive_opset_version:
91+
return schema
92+
return None

0 commit comments

Comments
 (0)