
Commit 4b0b190

Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into fix_issue_
2 parents 119c7b0 + 21fc700

8 files changed: +111 additions, -34 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -80,5 +80,5 @@ def run(self):
       author_email='[email protected]',
       url='https://github.com/onnx/tensorflow-onnx',
-      install_requires=['numpy>=1.14.1', 'onnx>=1.2.2', 'six']
+      install_requires=['numpy>=1.14.1', 'onnx>=1.4.1', 'six']
   )
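
The onnx floor moves from 1.2.2 to 1.4.1, presumably so the schema helpers added in tf2onnx/schemas.py below can rely on a recent opset registry. A quick environment check might look like this (a sketch; LooseVersion is an arbitrary comparison choice, not part of this commit):

    import onnx
    from distutils.version import LooseVersion

    # Fail fast if the installed onnx predates the new requirement.
    if LooseVersion(onnx.__version__) < LooseVersion("1.4.1"):
        raise RuntimeError("tf2onnx now requires onnx>=1.4.1, found " + onnx.__version__)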

tests/test_graph.py

Lines changed: 19 additions & 16 deletions
@@ -20,7 +20,7 @@
 import tf2onnx
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
 from tf2onnx.tfonnx import process_tf_graph
-from common import unittest_main
+from common import get_test_config, unittest_main

 _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"

@@ -96,6 +96,8 @@ def setUp(self):
         os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
         tf.logging.set_verbosity(tf.logging.WARN)

+        self.config = get_test_config()
+
         tf2onnx.utils.INTERNAL_NAME = 1
         tf.reset_default_graph()
         arg = namedtuple("Arg", "input inputs outputs verbose continue_on_error")
@@ -115,7 +117,7 @@ def test_abs(self):
             x = tf.placeholder(tf.float32, [2, 3], name="input")
             x_ = tf.abs(x)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual('digraph { input [op_type=Placeholder shape="[2, 3]"]' \
                              ' Abs [op_type=Abs] output [op_type=Identity] input:0 -> Abs Abs:0 -> output }',
                              onnx_to_graphviz(g))
@@ -127,7 +129,7 @@ def test_randomuniform(self):
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { RandomUniform__2 [op_type=RandomUniform shape="[2, 3]"] output1 [op_type=Identity] '
                 'output2 [op_type=Identity] output [op_type=Identity] RandomUniform__2:0 -> output1 '
@@ -138,7 +140,7 @@ def test_randomnormal(self):
         with tf.Session() as sess:
             x_ = tf.random_normal([2, 3], name="rand")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             actual = onnx_to_graphviz(g)
             expected = 'digraph { RandomNormal__2 [op_type=RandomNormal shape="[2, 3]"] output [op_type=Identity] ' \
                        'RandomNormal__2:0 -> output }'
@@ -154,7 +156,7 @@ def test_dropout(self):
             x_ = tf.identity(x_, name="output1")
             x_ = tf.identity(x_, name="output2")
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             actual = onnx_to_graphviz(g)
             expected = 'digraph { prob [op_type=Placeholder shape="[]"] input2 [op_type=Placeholder shape="[1, 3]"] ' \
                        'input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Add] Dropout__3 [op_type=Dropout] ' \
@@ -169,7 +171,7 @@ def test_add(self):
             x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
             x_ = tf.add(x1, x2)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> Add Add:0 -> output }',
@@ -181,7 +183,7 @@ def test_squareddifference(self):
             x2 = tf.placeholder(tf.float32, [1, 3], name="input2")
             x_ = tf.squared_difference(x1, x2)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input2 [op_type=Placeholder shape="[1, 3]"] input1 [op_type=Placeholder shape="[1, 3]"] '
                 'SquaredDifference [op_type=Sub] SquaredDifference__2 [op_type=Mul] '
@@ -195,7 +197,7 @@ def test_reducesum(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.reduce_sum(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { Const [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'Sum [op_type=ReduceSum] output [op_type=Identity] input1:0 -> Sum Sum:0 -> output }',
@@ -206,7 +208,7 @@ def test_argminmax(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.argmin(x1, axis=0)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { "ArgMin/dimension" [op_type=Const] input1 [op_type=Placeholder shape="[2, 3]"] '
                 'ArgMin [op_type=ArgMin] output [op_type=Identity] input1:0 -> ArgMin ArgMin:0 -> output }',
@@ -217,7 +219,7 @@ def test_rsqrt(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.rsqrt(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Rsqrt [op_type=Sqrt] '
                 'Rsqrt__2 [op_type=Reciprocal] output [op_type=Identity] input1:0 -> Rsqrt '
@@ -229,7 +231,7 @@ def test_relu6(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.nn.relu6(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Relu6 [op_type=Relu] Relu6__2 [op_type=Clip] '
                 'output [op_type=Identity] input1:0 -> Relu6 Relu6:0 -> Relu6__2 Relu6__2:0 -> output }',
@@ -257,7 +259,7 @@ def test_conv2d(self):
             sess.run(tf.global_variables_initializer())
             _ = sess.run(conv, feed_dict={image_: image})

-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[1, 4, 4, 1]"] Conv2D__3 [op_type=Transpose] '
                 '"kernel/shape" [op_type=Const] kernel__2 [op_type=Cast] k [op_type=Const] '
@@ -272,7 +274,7 @@ def test_squeeze(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.squeeze(x1)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Squeeze [op_type=Squeeze] '
                 'output [op_type=Identity] input1:0 -> Squeeze Squeeze:0 -> output }',
@@ -283,7 +285,7 @@ def test_cast(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.cast(x1, tf.int32)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Cast [op_type=Cast] output [op_type=Identity] '
                 'input1:0 -> Cast Cast:0 -> output }',
@@ -294,7 +296,7 @@ def test_reshape(self):
             x1 = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.reshape(x1, [3, 2])
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph)
+            g = process_tf_graph(sess.graph, opset=self.config.opset)
             self.assertEqual(
                 'digraph { "Reshape/shape" [op_type=Const] Reshape__2 [op_type=Cast] '
                 'input1 [op_type=Placeholder shape="[2, 3]"] Reshape [op_type=Reshape] '
@@ -319,7 +321,7 @@ def rewrite_test(g, ops):
             x = tf.placeholder(tf.float32, [2, 3], name="input1")
             x_ = tf.add(x, x)
             _ = tf.identity(x_, name="output")
-            g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
+            g = process_tf_graph(sess.graph, opset=self.config.opset, custom_rewriter=[rewrite_test])
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Add [op_type=Mul] '
                 'output [op_type=Identity] input1:0 -> Add input1:0 -> Add Add:0 -> output }',
@@ -345,6 +347,7 @@ def print_handler(ctx, node, name, args):  # pylint: disable=unused-argument
             _ = tf.identity(x_, name="output")
             g = process_tf_graph(sess.graph,
                                  custom_op_handlers={"Print": (print_handler, ["Identity", "mode"])},
+                                 opset=self.config.opset,
                                  extra_opset=helper.make_opsetid(_TENSORFLOW_DOMAIN, 1))
             self.assertEqual(
                 'digraph { input1 [op_type=Placeholder shape="[2, 3]"] Print [op_type=Identity] '
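
Every hunk above makes the same change: the tests stop relying on the converter's default opset and pass one explicitly from a shared test config. A minimal standalone version of the pattern (assuming the object returned by get_test_config exposes an opset attribute, as the setUp hunk suggests):

    import tensorflow as tf
    from tf2onnx.tfonnx import process_tf_graph

    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [2, 3], name="input")
        _ = tf.identity(tf.abs(x), name="output")
        # opset normally comes from get_test_config(); 7 is just an example value.
        g = process_tf_graph(sess.graph, opset=7)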

tests/test_internals.py

Lines changed: 29 additions & 0 deletions
@@ -218,6 +218,35 @@ def test_data_format(self):
         self.assertEqual(n.data_format, "NHWC")
         self.assertTrue(n.is_nhwc())

+    def test_node_attr_onnx(self):
+        n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", my_attr="my_attr")
+        graph_proto = helper.make_graph(
+            nodes=[n1],
+            name="test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
+                    helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
+            initializer=[]
+        )
+        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        n1 = g.get_node_by_name("n1")
+        self.assertTrue("my_attr" in n1.attr)
+        self.assertTrue("my_attr" not in n1.attr_onnx)
+
+        n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", domain="my_domain", my_attr="my_attr")
+        graph_proto = helper.make_graph(
+            nodes=[n1],
+            name="test",
+            inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
+                    helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
+            outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
+            initializer=[]
+        )
+        g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
+        n1 = g.get_node_by_name("n1")
+        self.assertTrue("my_attr" in n1.attr)
+        self.assertTrue("my_attr" in n1.attr_onnx)
+

 if __name__ == '__main__':
     unittest_main()

tf2onnx/convert.py

Lines changed: 1 addition & 2 deletions
@@ -10,7 +10,6 @@
 from __future__ import unicode_literals

 import argparse
-import onnx
 from onnx import helper
 import tensorflow as tf

@@ -80,7 +79,7 @@ def main():

     opset = utils.find_opset(args.opset)
     print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
-        tf.__version__, onnx.__version__, opset,
+        tf.__version__, utils.get_onnx_version(), opset,
         tf2onnx.__version__, tf2onnx.version.git_version[:6]))

     # override unknown dimensions from -1 to 1 (aka batchsize 1) since not every runtime does
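
utils.get_onnx_version() is used here but not defined in this diff. A plausible minimal implementation (an assumption, not code from this commit) would defer the onnx import, consistent with dropping the module-level import onnx above:

    # Hypothetical sketch of tf2onnx.utils.get_onnx_version; the real helper
    # is not shown in this commit.
    def get_onnx_version():
        """Return the version string of the installed onnx package."""
        import onnx  # imported lazily, only when the version is needed
        return onnx.__version__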

tf2onnx/graph.py

Lines changed: 23 additions & 13 deletions
@@ -11,6 +11,7 @@

 import collections
 import copy
+import logging
 import sys
 import traceback
 import six
@@ -22,6 +23,9 @@
 from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
 from tf2onnx.schemas import get_schema

+logging.basicConfig(level=logging.INFO)
+log = logging.getLogger("graph")
+

 # todo(pengwa): remove protected-access later
 # pylint: disable=broad-except,protected-access
@@ -83,12 +87,16 @@ def attr(self):

     @property
     def attr_onnx(self):
+        """Return onnx valid attributes"""
+        schema = get_schema(self.type, self.graph.opset, self.domain)
+        if schema is None and not (self.is_const() or self.is_graph_input()):
+            log.warning("Node %s uses non-standard onnx op <%s, %s>, skip attribute check", self.name, self.domain,
+                        self.type)
+
         onnx_attrs = {}
         for a in self._attr.values():
-            schema = get_schema(self.type, self.graph.opset)
-            if schema:
-                if schema.has_attribute(a.name):
-                    onnx_attrs[a.name] = a
+            if schema is None or schema.has_attribute(a.name):
+                onnx_attrs[a.name] = a
         return onnx_attrs

     @property
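
A note on the attr_onnx hunk above: before this change, a node whose op had no registered schema produced an empty attr_onnx, silently dropping every attribute. Now the schema is looked up once per node (including its domain), a schema-less node keeps all of its attributes, and a warning is logged unless the node is a constant or a graph input. That is exactly the contract pinned down by the new test_node_attr_onnx in tests/test_internals.py.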
@@ -358,8 +366,10 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=None
                 self.remove_node(n.name)

             new_outputs = [output if output != o else new_output_name for output in n.output]
+            # domain should be passed to new node
             new_node = self.make_node(n.type, n.input, outputs=new_outputs, attr=n.attr, name=n.name,
-                                      skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes)
+                                      skip_conversion=n._skip_conversion, dtypes=n_dtypes, shapes=n_shapes,
+                                      domain=n.domain)

             if body_graphs:
                 for attr_name, body_graph in body_graphs.items():
@@ -404,7 +414,7 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
         return node

     def make_node(self, op_type, inputs, attr=None, output_count=1, outputs=None, skip_conversion=True,
-                  op_name_scope=None, name=None, shapes=None, dtypes=None):
+                  op_name_scope=None, name=None, shapes=None, dtypes=None, domain=None):
         """Make a new onnx node in the graph"""
         if attr is None:
             attr = {}
@@ -437,7 +447,7 @@
             n = self.get_node_by_output_in_current_graph(o)
             utils.make_sure(n is None, "output tensor named %s already exists in node: \n%s", o, n)

-        onnx_node = helper.make_node(op_type, inputs, outputs, name=name, **raw_attr)
+        onnx_node = helper.make_node(op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

         if op_type in ["If", "Loop", "Scan"]:
             # we force the op containing inner graphs not skipped during conversion.
@@ -864,7 +874,7 @@ def remove_input(node, to_be_removed):
             # don't remove output from parent since others might depend on it
             return True

-    def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
+    def insert_new_node_on_input(self, node, op_type, input_name, name=None, domain=None, **kwargs):
         """Create and insert a new node into the graph.
         Args:
             node: we want to replace the input for this node
@@ -879,14 +889,14 @@ def insert_new_node_on_input(self, node, op_type, input_name, name=None, **kwargs):
         if name is None:
             name = utils.make_name(node.name)
         new_output = port_name(name)
-        new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name)
+        new_node = self.make_node(op_type, [input_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)
         for i, n in enumerate(node.input):
             if n == input_name:
                 node.input[i] = new_output
                 break
         return new_node

-    def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
+    def insert_new_node_on_output(self, op_type, output_name, name=None, domain=None, **kwargs):
         """Create and insert a new node into the graph.
         Args:
             op_type: type for new operation
@@ -903,7 +913,7 @@ def insert_new_node_on_output(self, op_type, output_name, name=None, **kwargs):
             type(op_type))

         new_output = port_name(name)
-        new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name)
+        new_node = self.make_node(op_type, [output_name], attr=kwargs, outputs=[new_output], name=name, domain=domain)

         to_replace = [n for n in self.get_nodes() if n != new_node]
         self.replace_all_inputs(to_replace, output_name, new_output)
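
All three graph-construction entry points above now accept a domain and forward it to helper.make_node, so custom-domain ops survive graph surgery. A brief usage sketch (the op type and domain are illustrative only; g and node stand for an existing Graph and Node):

    # Insert an op from the tensorflow converter domain in front of the first
    # input of `node`; extra keyword arguments would become ONNX attributes.
    new_node = g.insert_new_node_on_input(
        node, "Print", node.input[0],
        domain="ai.onnx.converters.tensorflow")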
@@ -1042,7 +1052,7 @@ def optimize_graph(graph, doc_string, optimize=None, debug=False):
     try:
         opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
                 IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
-               ]
+                ]
         for opt in opts:
             opt.optimize()
         model_proto = graph.make_model(doc_string, optimize=optimize)
@@ -1068,7 +1078,7 @@ def optimize_graph_with_model_proto(onnx_model_proto, debug=False):

     opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
             IdentityOptimizer(g, output_names=g.outputs, debug=debug)
-           ]
+            ]
     for opt in opts:
         opt.optimize()

tf2onnx/schemas.py

Lines changed: 23 additions & 1 deletion
@@ -12,7 +12,6 @@
 from collections import defaultdict, OrderedDict
 from onnx import defs

-
 ONNX_DOMAIN = ""


@@ -77,16 +76,39 @@ def _register_all_schemas_with_history():
     return ordered_map


+def _parse_domain_opset_versions(schemas):
+    """ Get max opset version among all schemas within each domain. """
+    domain_opset_versions = dict()
+    for domain_version_schema_map in schemas.values():
+        for domain, version_schema_map in domain_version_schema_map.items():
+            # version_schema_map is sorted by since_version in descending order
+            max_version = next(iter(version_schema_map))
+            if domain not in domain_opset_versions:
+                domain_opset_versions[domain] = int(max_version)
+            else:
+                domain_opset_versions[domain] = max(domain_opset_versions[domain], int(max_version))
+    return domain_opset_versions
+
+
 # format is <OpName, <Domain, <SinceVersion, OpSchema>>>
 # SinceVersion is sorted from high to low
 _schemas = _register_all_schemas_with_history()

+_domain_opset_versions = _parse_domain_opset_versions(_schemas)
+

 def get_schema(name, max_inclusive_opset_version, domain=ONNX_DOMAIN):
     """Get schema by name within specific version."""
+    domain = domain or ONNX_DOMAIN
     domain_version_schema_map = _schemas[name]
     version_schema_map = domain_version_schema_map[domain]
     for version, schema in version_schema_map.items():
         if version <= max_inclusive_opset_version:
             return schema
     return None
+
+
+def get_max_supported_opset_version(domain=ONNX_DOMAIN):
+    """Get max supported opset version by current onnx package given a domain."""
+    domain = domain or ONNX_DOMAIN
+    return _domain_opset_versions.get(domain, None)
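
A quick usage sketch of the new helper (return values depend on the installed onnx package):

    from tf2onnx import schemas

    # Highest opset version the installed onnx release registers for the
    # default (onnx) domain.
    print(schemas.get_max_supported_opset_version())

    # Unknown domains return None rather than raising.
    print(schemas.get_max_supported_opset_version("my_domain"))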
