
Commit 0918351

Merge branch 'master' of https://github.com/onnx/tensorflow-onnx into opset_9_scan
2 parents: aa6f345 + dba60d5


9 files changed: +322 -100 lines changed


tests/run_pretrained_models.py

Lines changed: 2 additions & 3 deletions
@@ -153,7 +153,7 @@ def run_tensorflow(self, sess, inputs):
 
     def to_onnx(self, tf_graph, opset=None, shape_override=None, input_names=None):
         """Convert graph to tensorflow."""
-        return process_tf_graph(tf_graph, continue_on_error=True, verbose=True, opset=opset,
+        return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
                                 target=Test.target, shape_override=shape_override,
                                 input_names=input_names, output_names=self.output_names)
 
@@ -186,7 +186,6 @@ def run_onnxruntime(self, name, model_proto, inputs):
         """Run test against msrt-next backend."""
         import onnxruntime as rt
         model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
-        utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=False, as_text=True)
         print("\t\t" + model_path)
         m = rt.InferenceSession(model_path)
         results = m.run(self.output_names, inputs)
@@ -266,7 +265,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=No
             onnx_graph = self.to_onnx(sess.graph, opset=opset, shape_override=shape_override,
                                       input_names=inputs.keys())
             model_proto = onnx_graph.make_model("converted from tf2onnx")
-            new_model_proto = GraphUtil.opt_transposes_with_graph(onnx_graph, "test", debug=debug)
+            new_model_proto = GraphUtil.optimize_graph(onnx_graph, "test", debug=debug)
             if new_model_proto:
                 model_proto = new_model_proto
             else:

tests/test_backend.py

Lines changed: 18 additions & 14 deletions
@@ -620,23 +620,27 @@ def test_logicaland(self):
 
     @check_onnxruntime_incompatibility("Greater")
     def test_greater(self):
-        x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
-        x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
-        x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
-        x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
-        mi = tf.greater(x1, x2)
-        _ = tf.identity(mi, name=_TFOUTPUT)
-        self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
+        for op in [tf.greater, tf.greater_equal]:
+            tf.reset_default_graph()
+            x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
+            x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))
+            x1 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
+            x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
+            mi = op(x1, x2)
+            _ = tf.identity(mi, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
 
     @check_onnxruntime_incompatibility("Greater")
     def test_greater_unsupport_type(self):
-        x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
-        x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
-        x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
-        x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
-        mi = tf.greater(x1, x2)
-        _ = tf.identity(mi, name=_TFOUTPUT)
-        self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
+        for op in [tf.greater, tf.greater_equal]:
+            tf.reset_default_graph()
+            x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
+            x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
+            x1 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT)
+            x2 = tf.placeholder(tf.int32, [2, 2], name=_TFINPUT1)
+            mi = op(x1, x2)
+            _ = tf.identity(mi, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})
 
     @check_onnxruntime_incompatibility("Less")
     def test_less(self):
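
For context (not part of the commit): the updated tests iterate over both tf.greater and tf.greater_equal, resetting the default graph between iterations so each op is converted from a clean graph. A minimal standalone sketch of that pattern, using plain TensorFlow 1.x session execution instead of the test harness; tensor names here are illustrative placeholders.

import numpy as np
import tensorflow as tf

x_val1 = np.array([4, 2, 4, 1], dtype=np.float32).reshape((2, 2))
x_val2 = np.array([2, 4, 4, 1], dtype=np.float32).reshape((2, 2))

for op in [tf.greater, tf.greater_equal]:
    tf.reset_default_graph()  # each comparison op gets its own graph, mirroring the test loop
    x1 = tf.placeholder(tf.float32, [2, 2], name="input1")
    x2 = tf.placeholder(tf.float32, [2, 2], name="input2")
    out = tf.identity(op(x1, x2), name="output")
    with tf.Session() as sess:
        print(op.__name__, sess.run(out, feed_dict={x1: x_val1, x2: x_val2}))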

tests/test_optimizers.py

Lines changed: 156 additions & 21 deletions
@@ -9,6 +9,7 @@
 
 import numpy as np
 from onnx import helper, TensorProto
+from tf2onnx import utils
 from tf2onnx.graph import GraphUtil
 from backend_test_base import Tf2OnnxBackendTestBase
 from common import unittest_main
@@ -19,22 +20,23 @@
 class OptimizerTests(Tf2OnnxBackendTestBase):
     """Run original model proto and modified model proto with onnxruntime, compare the results."""
 
-    def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
-                        remaining_transpose_num=None, debug=False, rtol=1e-07):
+    def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto, op_type,
+                        remaining_op_num, debug=False, rtol=1e-07):
+        utils.make_sure(op_type is not None, "op_type should be specified")
+        utils.make_sure(remaining_op_num is not None, "remaining_op_num should be specified")
+
         origin_model_path = self.save_onnx_model(origin_proto, onnx_feed_dict, postfix="_origin")
 
-        new_proto = GraphUtil.opt_transposes_with_model_proto(origin_proto)
+        new_proto = GraphUtil.optimize_graph_with_model_proto(origin_proto)
 
         self.assertTrue(new_proto, msg="model proto after optimizer should not be None")
 
         new_model_path = self.save_onnx_model(new_proto, onnx_feed_dict, postfix="_opt")
-
-        previous = GraphUtil.get_node_count_from_onnx_graph(origin_proto.graph)
         current = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
 
-        self.assertTrue(current["Transpose"] < previous["Transpose"], msg="transpose ops count not changed")
-        if remaining_transpose_num is not None:
-            self.assertTrue(current["Transpose"] == remaining_transpose_num, msg="some transpose ops left unexpected")
+        self.assertTrue(current[op_type] == remaining_op_num,
+                        msg="Expect " + str(remaining_op_num) + " " + op_type + " ops left, but actually " +
+                            str(current[op_type]) + " left")
 
         if self.config.is_onnxruntime_backend:
             expected = self.run_onnxruntime(origin_model_path, onnx_feed_dict, output_names_with_port)
4749
self.assertEqual(expected_val.dtype, actual_val.dtype)
4850
self.assertEqual(expected_val.shape, actual_val.shape)
4951

50-
def test_relu(self):
52+
# Tranpose Optimizer Tests Start
53+
54+
def run_transpose_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
55+
remaining_transpose_num=None, debug=False, rtol=1e-07):
56+
self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type="Transpose",
57+
remaining_op_num=remaining_transpose_num, debug=debug, rtol=rtol)
58+
59+
def test_transpose_relu(self):
5160
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
5261
node2 = helper.make_node("Relu", ["Y"], ["Z"], name="relu")
5362
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
@@ -60,10 +69,10 @@ def test_relu(self):
         )
 
         model_proto = helper.make_model(graph, producer_name="onnx-tests")
-        self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
-                             model_proto, remaining_transpose_num=0)
+        self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
 
-    def test_leaky_relu(self):
+    def test_transpose_leaky_relu(self):
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
         node2 = helper.make_node("LeakyRelu", ["Y"], ["Z"], alpha=0.02, name="relu")
         node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=[0, 3, 1, 2], name="trans_2")
@@ -76,10 +85,10 @@ def test_leaky_relu(self):
         )
 
         model_proto = helper.make_model(graph, producer_name="onnx-tests")
-        self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
-                             model_proto, remaining_transpose_num=0)
+        self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
 
-    def test_max(self):
+    def test_transpose_max(self):
         const_1_val = [2.0]
         const_1 = helper.make_tensor("const_1", TensorProto.FLOAT, (1,), const_1_val)
         const_1_node = helper.make_node("Constant", [], ["const_1"], value=const_1, name="const_1")
@@ -104,8 +113,8 @@ def test_max(self):
         )
 
         model_proto = helper.make_model(graph, producer_name="onnx-tests")
-        self.run_and_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
-                             model_proto, remaining_transpose_num=0)
+        self.run_transpose_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
 
     def test_transpose_merge(self):
         node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -120,8 +129,8 @@ def test_transpose_merge(self):
         )
 
         model_proto = helper.make_model(graph, producer_name="onnx-tests")
-        self.run_and_compare(["OUT"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
-                             model_proto, remaining_transpose_num=1)
+        self.run_transpose_compare(["OUT"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=1)
 
     def test_transpose_with_shape(self):
         node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
@@ -135,9 +144,135 @@ def test_transpose_with_shape(self):
         )
 
         model_proto = helper.make_model(graph, producer_name="onnx-tests")
-        self.run_and_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
-                             model_proto, remaining_transpose_num=0)
+        self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                   model_proto, remaining_transpose_num=0)
+
+    # Tranpose Optimizer Tests End
+
+    # Identity Optimizer Tests Start
+
+    def run_identity_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
+                             remaining_identity_num=None, debug=False, rtol=1e-07):
+        self.run_and_compare(output_names_with_port, onnx_feed_dict, origin_proto, op_type="Identity",
+                             remaining_op_num=remaining_identity_num, debug=debug, rtol=rtol)
+
+    def test_identity_non_graph_output(self):
+        node1 = helper.make_node("Add", ["X", "X"], ["Y"], name="add")
+        node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity")
+        node3 = helper.make_node("Shape", ["Z"], ["Z1"], name="shape")
+
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "identity-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
+            [helper.make_tensor_value_info("Z1", TensorProto.INT64, [4])],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_identity_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                  model_proto, remaining_identity_num=0)
+
+    def test_identity_unremovable_identity(self):
+        # should not remove!!
+        node1 = helper.make_node("Identity", ["X"], ["Y"], name="identity")
+
+        graph = helper.make_graph(
+            [node1],
+            "identity-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
+            [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 3, 4, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_identity_compare(["Y"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                  model_proto, remaining_identity_num=1)
+
+    def test_identity_output_as_multiple_graph_outputs(self):
+        # handle case like this, both Identity nodes are graph outputs,
+        #          Add
+        #         /    \
+        #  Identity   Identity
+        # We at most can remove one Identity for this case.
+        node1 = helper.make_node("Add", ["X", "X"], ["Y"], name="identity")
+        node2 = helper.make_node("Identity", ["Y"], ["Z1"], name="identity2")
+        node3 = helper.make_node("Identity", ["Y"], ["Z2"], name="identity3")
+        graph = helper.make_graph(
+            [node1, node2, node3],
+            "identity-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
+            [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, (2, 3, 4, 5)),
+             helper.make_tensor_value_info("Z2", TensorProto.FLOAT, (2, 3, 4, 5))],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_identity_compare(["Z1", "Z2"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                  model_proto, remaining_identity_num=1)
+
+    def test_identity_in_subgraph_non_graph_output(self):
+        node1 = helper.make_node("Add", ["X", "X"], ["Y"], name="add")
+
+        iter_num_value = np.array(1, dtype=np.int64)
+        node2 = helper.make_node(
+            'Constant',
+            inputs=[],
+            outputs=['iterate_num_value'],
+            value=helper.make_tensor(
+                name='iterate_num_value',
+                data_type=TensorProto.INT64,
+                dims=iter_num_value.shape,
+                vals=iter_num_value.flatten().astype(np.int64),
+            ),
+        )
+
+        cond_value = np.array(True, dtype=np.bool)
+        node3 = helper.make_node(
+            'Constant',
+            inputs=[],
+            outputs=['cond_value'],
+            value=helper.make_tensor(
+                name='cond_value',
+                data_type=TensorProto.BOOL,
+                dims=iter_num_value.shape,
+                vals=cond_value.flatten().astype(np.bool),
+            ),
+        )
+
+        # sub graph
+        sub_node1 = helper.make_node("Add", ["loop_var_1", "loop_var_1"], ["SubY"], name="sub_add")
+        sub_node2 = helper.make_node("Identity", ["SubY"], ["SubIdentity1"], name="sub_identity_1")
+        sub_node3 = helper.make_node("Identity", ["SubIdentity1"], ["loop_var_out_1"], name="sub_identity_2")
+        sub_node4 = helper.make_node("Identity", ["loop_condition"], ["loop_cond_output"], name="sub_identity_3")
+        sub_graph = helper.make_graph(
+            [sub_node1, sub_node2, sub_node3, sub_node4],
+            "identity_subgraph-test",
+            [helper.make_tensor_value_info("loop_iter_num", TensorProto.INT64, (1,)),  # iteration_num
+             helper.make_tensor_value_info("loop_condition", TensorProto.BOOL, ()),  # condition
+             helper.make_tensor_value_info("loop_var_1", TensorProto.FLOAT, ()),  # loop-carried dependency
+             ],
+            [helper.make_tensor_value_info("loop_cond_output", TensorProto.BOOL, ()),
+             helper.make_tensor_value_info("loop_var_out_1", TensorProto.FLOAT, ())
+             ],
+        )
+        # sub graph ends
+
+        loop_node = helper.make_node("Loop", ["iterate_num_value", "cond_value", "Y"], ["loop_var_1_output"],
+                                     name="loop", body=sub_graph)
+
+        node4 = helper.make_node("Identity", ["loop_var_1_output"], ["Z"], name="identity")
+        node5 = helper.make_node("Shape", ["Z"], ["Z1"], name="shape")
+
+        graph = helper.make_graph(
+            [node1, node2, node3, loop_node, node4, node5],
+            "identity-test",
+            [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
+            [helper.make_tensor_value_info("Z1", TensorProto.INT64, [4])],
+        )
+
+        model_proto = helper.make_model(graph, producer_name="onnx-tests")
+        self.run_identity_compare(["Z1"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
+                                  model_proto, remaining_identity_num=0)
 
+    # Tranpose Optimizer Tests End
 
 if __name__ == "__main__":
     unittest_main()
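
For context (not part of the commit): the generalized run_and_compare now asserts an exact remaining count for a given op type. A minimal standalone sketch of the same check outside the test harness, mirroring test_identity_non_graph_output and using only the APIs touched by this change; node and tensor names are illustrative.

from onnx import helper, TensorProto
from tf2onnx.graph import GraphUtil

# Add -> Identity -> Shape; the Identity feeds another node, so it should be folded away.
node1 = helper.make_node("Add", ["X", "X"], ["Y"], name="add")
node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity")
node3 = helper.make_node("Shape", ["Z"], ["Z1"], name="shape")
graph = helper.make_graph(
    [node1, node2, node3],
    "identity-example",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
    [helper.make_tensor_value_info("Z1", TensorProto.INT64, [4])],
)
model_proto = helper.make_model(graph, producer_name="onnx-tests")

new_proto = GraphUtil.optimize_graph_with_model_proto(model_proto)
counts = GraphUtil.get_node_count_from_onnx_graph(new_proto.graph)
assert counts["Identity"] == 0  # same expectation as test_identity_non_graph_output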

tf2onnx/convert.py

Lines changed: 2 additions & 2 deletions
@@ -126,8 +126,8 @@ def main():
 
     model_proto = g.make_model("converted from {}".format(args.input))
 
-    new_model_proto = GraphUtil.opt_transposes_with_graph(g, "converted from {}".format(model_path),
-                                                          optimize=not args.continue_on_error)
+    new_model_proto = GraphUtil.optimize_graph(g, "converted from {}".format(model_path),
+                                               optimize=not args.continue_on_error)
     if new_model_proto:
         model_proto = new_model_proto
     else:

tf2onnx/graph.py

Lines changed: 20 additions & 10 deletions
@@ -19,7 +19,7 @@
 from onnx import helper, numpy_helper, optimizer, shape_inference, OperatorSetIdProto, AttributeProto
 from tf2onnx import utils, __version__
 from tf2onnx.utils import port_name, find_opset
-from tf2onnx.optimizer.transpose_optimizer import TransposeOptimizer
+from tf2onnx.optimizer import IdentityOptimizer, TransposeOptimizer
 from tf2onnx.schemas import get_schema
 
 
@@ -854,6 +854,10 @@ def dump_node_statistics(self):
         op_cnt = collections.Counter()
         for n in self.get_nodes():
             op_cnt[n.type] += 1
+            body_graphs = n.get_body_graphs()
+            if body_graphs:
+                for _, b_g in body_graphs.items():
+                    op_cnt += b_g.dump_node_statistics()
 
         return op_cnt
 
@@ -1040,16 +1044,19 @@ class GraphUtil(object):
     """Utilities for Graph manipulation."""
 
     @staticmethod
-    def opt_transposes_with_graph(graph, doc_string, optimize=None, debug=False):
-        """Optimize the graph, eliminating all useless Transpose pairs.
+    def optimize_graph(graph, doc_string, optimize=None, debug=False):
+        """Optimize the graph, for example: eliminating all useless Transpose/Identity pairs.
 
         Returns:
             model proto after optimization, if optimizer run successfully
             or None, if exceptions happen
         """
         try:
-            opt = TransposeOptimizer(graph, output_names=graph.outputs, debug=debug)
-            opt.optimize()
+            opts = [TransposeOptimizer(graph, output_names=graph.outputs, debug=debug),
+                    IdentityOptimizer(graph, output_names=graph.outputs, debug=debug)
+                    ]
+            for opt in opts:
+                opt.optimize()
             model_proto = graph.make_model(doc_string, optimize=optimize)
             return model_proto
         except Exception:
@@ -1060,19 +1067,22 @@ def opt_transposes_with_graph(graph, doc_string, optimize=None, debug=False):
             return None
 
     @staticmethod
-    def opt_transposes_with_model_proto(onnx_model_proto, debug=False):
-        """Optimize the model proto, eliminating all useless Transpose pairs.
+    def optimize_graph_with_model_proto(onnx_model_proto, debug=False):
+        """Optimize the model proto, for example: eliminating all useless Transpose pairs.
 
         Returns:
             model proto after optimization, if optimizer run successfully
             or None, if exceptions happens
         """
         try:
             kwargs = GraphUtil.get_onnx_model_properties(onnx_model_proto)
-
             g = GraphUtil.create_graph_from_onnx_model(onnx_model_proto)
-            opt = TransposeOptimizer(g, output_names=g.outputs, debug=debug)
-            opt.optimize()
+
+            opts = [TransposeOptimizer(g, output_names=g.outputs, debug=debug),
+                    IdentityOptimizer(g, output_names=g.outputs, debug=debug)
+                    ]
+            for opt in opts:
+                opt.optimize()
 
             model_proto = g.make_model(onnx_model_proto.graph.doc_string,
                                        graph_name=onnx_model_proto.graph.name, **kwargs)
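
A hedged usage sketch (not part of the commit): with the renamed entry points, an already-exported ONNX file can be re-optimized offline, and both the Transpose and the new Identity pass run inside the single call. File paths below are placeholders.

import onnx
from tf2onnx.graph import GraphUtil

model_proto = onnx.load("model.onnx")  # placeholder path
new_proto = GraphUtil.optimize_graph_with_model_proto(model_proto)
if new_proto is not None:  # returns None if one of the optimizers raised
    onnx.save(new_proto, "model.opt.onnx")  # placeholder path
else:
    print("optimization skipped, keeping the original model")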
