Skip to content

Commit 4354a03

Browse files
committed
add onnx package version check, skip onnx attr check for unknown op
1 parent a2b8085 commit 4354a03

File tree

5 files changed

+52
-9
lines changed

5 files changed

+52
-9
lines changed

tf2onnx/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def main():
8080

8181
opset = utils.find_opset(args.opset)
8282
print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
83-
tf.__version__, onnx.__version__, opset,
83+
tf.__version__, utils.get_onnx_version(), opset,
8484
tf2onnx.__version__, tf2onnx.version.git_version[:6]))
8585

8686
# override unknown dimensions from -1 to 1 (aka batchsize 1) since not every runtime does

tf2onnx/graph.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import collections
1313
import copy
14+
import logging
1415
import sys
1516
import traceback
1617
import six
@@ -22,6 +23,9 @@
2223
from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
2324
from tf2onnx.schemas import get_schema
2425

26+
logging.basicConfig(level=logging.INFO)
27+
log = logging.getLogger("graph")
28+
2529

2630
# todo(pengwa): remove protected-access later
2731
# pylint: disable=broad-except,protected-access
@@ -83,12 +87,15 @@ def attr(self):
8387

8488
@property
8589
def attr_onnx(self):
90+
schema = get_schema(self.type, self.graph.opset, self.domain)
91+
if schema is None and not (self.is_const() or self.is_graph_input()):
92+
log.warning("Node %s uses non-stardard onnx op <%s, %s>, skip attribute check", self.name, self.domain,
93+
self.type)
94+
8695
onnx_attrs = {}
8796
for a in self._attr.values():
88-
schema = get_schema(self.type, self.graph.opset)
89-
if schema:
90-
if schema.has_attribute(a.name):
91-
onnx_attrs[a.name] = a
97+
if schema is None or schema.has_attribute(a.name):
98+
onnx_attrs[a.name] = a
9299
return onnx_attrs
93100

94101
@property
@@ -1054,7 +1061,7 @@ def optimize_graph(graph, doc_string, optimize=None, debug=False):
10541061
try:
10551062
opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
10561063
IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
1057-
]
1064+
]
10581065
for opt in opts:
10591066
opt.optimize()
10601067
model_proto = graph.make_model(doc_string, optimize=optimize)
@@ -1080,7 +1087,7 @@ def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
10801087

10811088
opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
10821089
IdentityOptimizer(g, output_names=g.outputs, debug=debug)
1083-
]
1090+
]
10841091
for opt in opts:
10851092
opt.optimize()
10861093

tf2onnx/schemas.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from collections import defaultdict, OrderedDict
1313
from onnx import defs
1414

15-
1615
ONNX_DOMAIN = ""
1716

1817

@@ -77,16 +76,39 @@ def _register_all_schemas_with_history():
7776
return ordered_map
7877

7978

79+
def _parse_domain_opset_versions(schemas):
80+
""" Get max opset version among all schemas within each domain. """
81+
domain_opset_versions = dict()
82+
for domain_version_schema_map in schemas.values():
83+
for domain, version_schema_map in domain_version_schema_map.items():
84+
# version_schema_map is sorted by since_version in descend order
85+
max_version = next(iter(version_schema_map))
86+
if domain not in domain_opset_versions:
87+
domain_opset_versions[domain] = int(max_version)
88+
else:
89+
domain_opset_versions[domain] = max(domain_opset_versions[domain], int(max_version))
90+
return domain_opset_versions
91+
92+
8093
# format is <OpName, <Domain, <SinceVersion, OpSchema>>>
8194
# SinceVersion is sorted from high to low
8295
_schemas = _register_all_schemas_with_history()
8396

97+
_domain_opset_versions = _parse_domain_opset_versions(_schemas)
98+
8499

85100
def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
86101
"""Get schema by name within specific version."""
102+
domain = domain or ONNX_DOMAIN
87103
domain_version_schema_map = _schemas[name]
88104
version_schema_map = domain_version_schema_map[domain]
89105
for version, schema in version_schema_map.items():
90106
if version <= max_inclusive_opset_version:
91107
return schema
92108
return None
109+
110+
111+
def get_max_supported_opset_version(domain=ONNX_DOMAIN):
112+
"""Get max supported opset version by current onnx package given a domain."""
113+
domain = domain or ONNX_DOMAIN
114+
return _domain_opset_versions.get(domain, None)

tf2onnx/tfonnx.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tensorflow.tools.graph_transforms import TransformGraph
2121

2222
import tf2onnx
23-
from tf2onnx import utils
23+
from tf2onnx import schemas, utils
2424
from tf2onnx.function import * # pylint: disable=wildcard-import
2525
from tf2onnx.graph import Graph
2626
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
@@ -1461,6 +1461,7 @@ def reverse_op8(ctx, node, name, args):
14611461
node.input[0] = node.input[1]
14621462
node.input[1] = tmp
14631463

1464+
14641465
def reverse_op9(ctx, node, name, args):
14651466
# T output = ReverseSequence(T input, int32|int64 seq_lengths, @int seq_dim, @int batch_dim)
14661467
# we cannot easily construct reverse_sequence equivalence in opset 9, so we will not support it
@@ -1649,6 +1650,7 @@ def logical_compare_op(ctx, node, name, args):
16491650
ctx.copy_shape(inp, inp_cast.output[0])
16501651
ctx.set_dtype(inp_cast.output[0], target_dtype)
16511652

1653+
16521654
def logical_compareeq_op(ctx, node, name, args):
16531655
logical_compare_op(ctx, node, name, [])
16541656
output_name = node.output[0]
@@ -2400,6 +2402,11 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24002402
Return:
24012403
onnx graph
24022404
"""
2405+
if opset > schemas.get_max_supported_opset_version():
2406+
log.warning("currently installed onnx package %s is too low to support opset %s, "
2407+
"please upgrade onnx package to avoid potential conversion issue.",
2408+
utils.get_onnx_version(), opset)
2409+
24032410
if shape_override is None:
24042411
shape_override = {}
24052412
if inputs_as_nchw is None:

tf2onnx/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import tensorflow as tf
1919
from tensorflow.core.framework import types_pb2, tensor_pb2
2020
from google.protobuf import text_format
21+
import onnx
2122
from onnx import helper, onnx_pb, defs, numpy_helper
2223

2324
#
@@ -74,13 +75,15 @@
7475
onnx_pb.TensorProto.BOOL: "bool"
7576
}
7677

78+
7779
class TensorValueInfo(object):
7880
def __init__(self, tensor_id, g):
7981
self.id = tensor_id
8082
if self.id:
8183
self.dtype = g.get_dtype(tensor_id)
8284
self.shape = g.get_shape(tensor_id)
8385

86+
8487
ONNX_UNKNOWN_DIMENSION = -1
8588

8689
# index for internally generated names
@@ -409,3 +412,7 @@ def are_shapes_equal(src, dest):
409412
def create_vague_shape_like(shape):
410413
make_sure(len(shape) >= 0, "rank should be >= 0")
411414
return [-1 for i in enumerate(shape)]
415+
416+
417+
def get_onnx_version():
418+
return onnx.__version__

0 commit comments

Comments
 (0)