Skip to content

Commit a0ab29f

Browse files
authored
Merge pull request #540 from zhijxu-MS/PR_resize_loop_trans_opt
fix several bugs in transpose opt
2 parents 89aca79 + 366dda9 commit a0ab29f

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

tests/run_pretrained_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def run_caffe2(self, name, model_proto, inputs):
170170
def run_onnxruntime(self, name, model_proto, inputs):
171171
"""Run test against onnxruntime backend."""
172172
import onnxruntime as rt
173-
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True)
173+
model_path = utils.save_onnx_model(TEMP_DIR, name, inputs, model_proto, include_test_data=True,
174+
as_text=utils.is_debug_mode())
174175
logger.info("Model saved to %s", model_path)
175176
m = rt.InferenceSession(model_path)
176177
results = m.run(self.output_names, inputs)

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def merge_duplicated_transposes(self):
131131
graph.delete_unused_nodes(graph.outputs)
132132

133133
def _optimize(self, graph):
134+
return self._apply_optimization(graph, self._optimize_at_current_graph_level)
135+
136+
def _optimize_at_current_graph_level(self, graph):
134137
self._g = graph
135138
self.pre_optimize_action()
136139
no_action = False
@@ -190,7 +193,7 @@ def _handle_node_having_branches(self, node):
190193

191194
# make sure node's all input transpose all have only 1 consumer node,
192195
# otherwise, it would impact their other output nodes
193-
if self._transpose_has_single_consumer_node(node.inputs):
196+
if self._nodes_has_single_consumer_node(node.inputs):
194197
self._create_transpose_pairs_after_node(node)
195198
input_transposes = node.inputs
196199
for n in input_transposes:
@@ -226,7 +229,7 @@ def _get_input_index_for_trans(self, node, trans):
226229

227230
# the assumption is: both node and trans have only 1 output
228231
def _switch_transpose_and_node(self, node, trans):
229-
if not self._transpose_has_single_consumer_node([trans]):
232+
if not self._nodes_has_single_consumer_node([trans]):
230233
return False
231234

232235
input_index = self._get_input_index_for_trans(node, trans)
@@ -274,13 +277,12 @@ def _remove_useless_tranpose(self, trans):
274277
self._g.replace_all_inputs(self._g.get_nodes(), trans.output[0], trans.input[0])
275278
self._g.remove_node(trans.name)
276279

277-
def _transpose_has_single_consumer_node(self, trans_nodes):
278-
result = True
279-
for n in trans_nodes:
280-
cnt = len(set(self._g.find_output_consumers(n.output[0])))
281-
result = result and cnt == 1
282-
if not result:
283-
return False
280+
def _nodes_has_single_consumer_node(self, nodes):
281+
for n in nodes:
282+
for output in n.output:
283+
cnt = len(set(self._g.find_output_consumers(output)))
284+
if cnt != 1:
285+
return False
284286
return True
285287

286288
def _get_non_nchw_transpose_output_nodes(self, node):
@@ -472,7 +474,7 @@ def _simple_through_handler(self, trans, node):
472474

473475
def _shape_handler(self, trans, node):
474476
# input > trans > shape can be changed into input > shape > gather
475-
if not self._transpose_has_single_consumer_node([trans]):
477+
if not self._nodes_has_single_consumer_node([trans]):
476478
return False
477479

478480
output_shape = self._g.get_shape(node.output[0])

0 commit comments

Comments
 (0)