
Commit 21fc700

Merge pull request #394 from nbcsm/onnx
refine onnx node attributes check
2 parents 63f3703 + e413565 commit 21fc700

7 files changed: 110 additions, 33 deletions


tests/test_graph.py

Lines changed: 19 additions & 16 deletions
@@ -20,7 +20,7 @@
 import tf2onnx
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
 from tf2onnx.tfonnx import process_tf_graph
-from common import unittest_main
+from common import get_test_config, unittest_main

 _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"

@@ -96,6 +96,8 @@ def setUp(self):
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         tf.logging.set_verbosity(tf.logging.WARN)

+        self.config = get_test_config()
+
         tf2onnx.utils.INTERNAL_NAME = 1
         tf.reset_default_graph()
         arg = namedtuple("Arg", "input inputs outputs verbose continue_on_error")
@@ -115,7 +117,7 @@ def test_abs(self):
             x = tf.placeholder(tf.float32, [2, 3], name="input")
             x_ = tf.abs(x)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual('digraph { input [op_type=Placeholder shape="[2, 3]"]' \
                              ' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
                              onnx_to_graphviz(g))
@@ -127,7 +129,7 @@ def test_randomuniform(self):
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { RandomUniform__2 [op_type=RandomUniform shape="[2, 3]"] output1 [op_type=Identity] '
                 'output2 [op_type=Identity] output [op_type=Identity] RandomUniform__2:0 -> output1 '
@@ -138,7 +140,7 @@ def test_randomnormal(self):
         with tf.Session() as sess:
             x_ = tf.random_normal([2, 3], name="rand")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             actual = onnx_to_graphviz(g)
             expected = 'digraph { RandomNormal__2 [op_type=RandomNormal shape="[2, 3]"] output [op_type=Identity] ' \
                        'RandomNormal__2:0 -> output }'
@@ -154,7 +156,7 @@ def test_dropout(self):
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             actual = onnx_to_graphviz(g)
             expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
                        'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
@@ -169,7 +171,7 @@ def test_add(self):
             x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
             x_ = tf.add(x1, x2)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }',
@@ -181,7 +183,7 @@ def test_squareddifference(self):
             x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
             x_ = tf.squared_difference(x1, x2)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
                 'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
@@ -195,7 +197,7 @@ def test_reducesum(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.reduce_sum(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
@@ -206,7 +208,7 @@ def test_argminmax(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.argmin(x1, axis=0)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
@@ -217,7 +219,7 @@ def test_rsqrt(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.rsqrt(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
                 'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
@@ -229,7 +231,7 @@ def test_relu6(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.nn.relu6(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
                 'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
@@ -257,7 +259,7 @@ def test_conv2d(self):
             sess.run(tf.global_variables_initializer())
             _ = sess.run(conv, feed_dict={image_: image})

-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
                 '"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
@@ -272,7 +274,7 @@ def test_squeeze(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.squeeze(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
                 'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
@@ -283,7 +285,7 @@ def test_cast(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.cast(x1, tf.int32)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '
                 'input1:0 -> Cast Cast:0 -> output }',
@@ -294,7 +296,7 @@ def test_reshape(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.reshape(x1, [3, 2])
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
                 'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
@@ -319,7 +321,7 @@ def rewrite_test(g, ops):
             x = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.add(x, x)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
+            g = process_tf_graph(sess.graph, opset=self.config.opset, custom_rewriter=[rewrite_test])
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
                 'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
@@ -345,6 +347,7 @@ def print_handler(ctx, node, name, args):  # pylint: disable=unused-argument
             _ = tf.identity(x_, name="output")
             g = process_tf_graph(sess.graph,
                                  custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
+                                 opset=self.config.opset,
                                  extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
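
The recurring change in this file is mechanical: every process_tf_graph call now pins the opset taken from the shared test config instead of relying on the converter's default. A minimal sketch of the pattern, assuming (as the diff implies) that get_test_config() in tests/common.py returns an object exposing an opset attribute:

    import tensorflow as tf
    from common import get_test_config  # tests/common.py helper imported above
    from tf2onnx.tfonnx import process_tf_graph

    config = get_test_config()  # exposes .opset, as used in setUp()
    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [2, 3], name="input")
        _ = tf.identity(tf.abs(x), name="output")
        # pin the target opset explicitly, mirroring the edited tests
        g = process_tf_graph(sess.graph, opset=config.opset)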

tests/test_internals.py

Lines changed: 29 additions & 0 deletions
@@ -218,6 +218,35 @@ def test_data_format(self):
         self.assertEqual(n.data_format, "NHWC")
         self.assertTrue(n.is_nhwc())

+    def test_node_attr_onnx(self):
+        n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", my_attr="my_attr")
+        graph_proto = helper.make_graph(
+            nodes=[n1],
+            name="test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
+                    helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
+            initializer=[]
+        )
+        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        n1 = g.get_node_by_name("n1")
+        self.assertTrue("my_attr" in n1.attr)
+        self.assertTrue("my_attr" not in n1.attr_onnx)
+
+        n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
+        graph_proto = helper.make_graph(
+            nodes=[n1],
+            name="test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
+                    helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
+            initializer=[]
+        )
+        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        n1 = g.get_node_by_name("n1")
+        self.assertTrue("my_attr" in n1.attr)
+        self.assertTrue("my_attr" in n1.attr_onnx)
+

 if __name__ == '__main__':
     unittest_main()

tf2onnx/convert.py

Lines changed: 1 addition & 2 deletions
@@ -10,7 +10,6 @@
 from __future__ import unicode_literals

 import argparse
-import onnx
 from onnx import helper
 import tensorflow as tf

@@ -80,7 +79,7 @@ def main():

     opset = utils.find_opset(args.opset)
     print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
-        tf.__version__, onnx.__version__, opset,
+        tf.__version__, utils.get_onnx_version(), opset,
         tf2onnx.__version__, tf2onnx.version.git_version[:6]))

     # override unknown dimensions from -1 to 1 (aka batchsize 1) since not every runtime does
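
This change drops convert.py's module-level dependency on the onnx package; the banner now gets the version string from tf2onnx.utils. The helper's body is not part of this diff, so the following is only a plausible sketch of what utils.get_onnx_version() could look like:

    import onnx

    def get_onnx_version():
        """Return the version string of the installed onnx package (hypothetical sketch)."""
        return onnx.__version__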

tf2onnx/graph.py

Lines changed: 23 additions & 13 deletions
@@ -11,6 +11,7 @@

 import collections
 import copy
+import logging
 import sys
 import traceback
 import six
@@ -22,6 +23,9 @@
 from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
 from tf2onnx.schemas import get_schema

+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("graph")
+

 # todo(pengwa): remove protected-access later
 # pylint: disable=broad-except,protected-access
@@ -83,12 +87,16 @@ def attr(self):

     @property
     def attr_onnx(self):
+        """Return onnx valid attributes"""
+        schema = get_schema(self.type, self.graph.opset, self.domain)
+        if schema is None and not (self.is_const() or self.is_graph_input()):
+            log.warning("Node %s uses non-standard onnx op <%s, %s>, skip attribute check", self.name, self.domain,
+                        self.type)
+
         onnx_attrs = {}
         for a in self._attr.values():
-            schema = get_schema(self.type, self.graph.opset)
-            if schema:
-                if schema.has_attribute(a.name):
-                    onnx_attrs[a.name] = a
+            if schema is None or schema.has_attribute(a.name):
+                onnx_attrs[a.name] = a
         return onnx_attrs

     @property
@@ -370,8 +378,10 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
             self.remove_node(n.name)

             new_outputs = [output if output != o else new_output_name for output in n.output]
+            # domain should be passed to new node
             new_node = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
-                                      skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes)
+                                      skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
+                                      domain=n.domain)

             if body_graphs:
                 for attr_name, body_graph in body_graphs.items():
@@ -416,7 +426,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
         return node

     def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
-                  op_name_scope=None, name=None, shapes=None, dtypes=None):
+                  op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None):
         """Make a new onnx node in the graph"""
         if attr is None:
             attr = {}
@@ -449,7 +459,7 @@ def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, sk
             n = self.get_node_by_output_in_current_graph(o)
             utils.make_sure(n is None, "output tensor named %s already exists in node: \n%s", o, n)

-        onnx_node = helper.make_node(op_type, inputs, outputs, name=name, **raw_attr)
+        onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

         if op_type in ["If", "Loop", "Scan"]:
             # we force the op containing inner graphs not skipped during conversion.
@@ -876,7 +886,7 @@ def remove_input(node, to_be_removed):
         # don't remove output from parent since others might depend on it
         return True

-    def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
+    def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
         """Create and insert a new node into the graph.
         Args:
             node: we want to replace the input for this node
@@ -891,14 +901,14 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwarg
         if name is None:
             name = utils.make_name(node.name)
         new_output = port_name(name)
-        new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name)
+        new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
         for i, n in enumerate(node.input):
             if n == input_name:
                 node.input[i] = new_output
                 break
         return new_node

-    def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
+    def insert_new_node_on_output(self, op_type, output_name, name=None, domain=None, **kwargs):
         """Create and insert a new node into the graph.
         Args:
             op_type: type for new operation
@@ -915,7 +925,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
             type(op_type))

         new_output = port_name(name)
-        new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name)
+        new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)

         to_replace = [n for n in self.get_nodes() if n != new_node]
         self.replace_all_inputs(to_replace, output_name, new_output)
@@ -1054,7 +1064,7 @@ def optimize_graph(graph, doc_string, optimize=None, debug=False):
     try:
         opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
                 IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
-               ]
+                ]
         for opt in opts:
             opt.optimize()
         model_proto = graph.make_model(doc_string, optimize=optimize)
@@ -1080,7 +1090,7 @@ def optimize_graph_with_model_proto(onnx_model_proto, debug=False):

     opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
             IdentityOptimizer(g, output_names=g.outputs, debug=debug)
-           ]
+            ]
     for opt in opts:
         opt.optimize()

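
Two threads run through this file: attr_onnx now resolves the schema once per node, honoring the node's domain, and keeps all attributes (with a warning) when no schema is registered; and make_node plus both insert_new_node_on_* helpers now thread a domain argument through to helper.make_node. A hedged usage sketch, where g and node stand for an existing Graph and Node, and the domain value mirrors _TENSORFLOW_DOMAIN from tests/test_graph.py:

    # Insert an Identity from a custom domain in front of node's first input;
    # before this commit the domain could not be passed through these helpers.
    new_node = g.insert_new_node_on_input(node, "Identity", node.input[0],
                                          domain="ai.onnx.converters.tensorflow")

    # No schema is registered for the custom domain, so attr_onnx skips the
    # attribute check and keeps attributes instead of silently dropping them.
    print(new_node.attr_onnx)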

tf2onnx/schemas.py

Lines changed: 23 additions & 1 deletion
@@ -12,7 +12,6 @@
 from collections import defaultdict, OrderedDict
 from onnx import defs

-
 ONNX_DOMAIN = ""


@@ -77,16 +76,39 @@ def _register_all_schemas_with_history():
     return ordered_map


+def _parse_domain_opset_versions(schemas):
+    """ Get max opset version among all schemas within each domain. """
+    domain_opset_versions = dict()
+    for domain_version_schema_map in schemas.values():
+        for domain, version_schema_map in domain_version_schema_map.items():
+            # version_schema_map is sorted by since_version in descending order
+            max_version = next(iter(version_schema_map))
+            if domain not in domain_opset_versions:
+                domain_opset_versions[domain] = int(max_version)
+            else:
+                domain_opset_versions[domain] = max(domain_opset_versions[domain], int(max_version))
+    return domain_opset_versions
+
+
 # format is <OpName, <Domain, <SinceVersion, OpSchema>>>
 # SinceVersion is sorted from high to low
 _schemas = _register_all_schemas_with_history()

+_domain_opset_versions = _parse_domain_opset_versions(_schemas)
+

 def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
     """Get schema by name within specific version."""
+    domain = domain or ONNX_DOMAIN
     domain_version_schema_map = _schemas[name]
     version_schema_map = domain_version_schema_map[domain]
     for version, schema in version_schema_map.items():
         if version <= max_inclusive_opset_version:
             return schema
     return None
+
+
+def get_max_supported_opset_version(domain=ONNX_DOMAIN):
+    """Get max supported opset version by current onnx package given a domain."""
+    domain = domain or ONNX_DOMAIN
+    return _domain_opset_versions.get(domain, None)
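
A short usage sketch of the two module-level entry points (return values depend on the installed onnx package):

    from tf2onnx.schemas import get_schema, get_max_supported_opset_version

    schema = get_schema("Conv", 9)            # newest Conv schema with since_version <= 9
    print(schema.has_attribute("dilations"))  # the check Node.attr_onnx relies on
    print(get_max_supported_opset_version())  # highest opset the installed onnx package defines
    print(get_max_supported_opset_version("my_domain"))  # None for an unregistered domain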
