Skip to content

Commit d8b0df3

Browse files
authored
transpose optimizer wrongly removes a transpose (#1288)
* sig Signed-off-by: guschmue <[email protected]> * fix ut that depends on interla count Signed-off-by: guschmue <[email protected]> * pylint Signed-off-by: guschmue <[email protected]>
1 parent 1d07510 commit d8b0df3

File tree

3 files changed

+28
-25
lines changed

3 files changed

+28
-25
lines changed

tests/test_internals.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ def test_insert_node1(self):
8686
ops = g.get_nodes()
8787
g.topological_sort(ops)
8888
result = onnx_to_graphviz(g)
89-
expected = 'digraph { Placeholder__4 [op_type=Placeholder] ' \
89+
expected = 'digraph { Placeholder__5 [op_type=Placeholder] ' \
9090
'n1 [op_type=Abs] n7 [op_type=Abs] n2 [op_type=Abs] n3 [op_type=Abs] ' \
9191
'n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
92-
'n5_graph_outputs_Identity__3 [op_type=Identity] input -> n1 n1:0 -> n7 ' \
93-
'n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n6 ' \
94-
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
92+
'n5_graph_outputs_Identity__4 [op_type=Identity] input -> n1 n1:0 -> n7 ' \
93+
'n7:0 -> n2 n1:0 -> n3 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 n5_raw_output___3:0 -> n6 ' \
94+
'n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
9595
self.assertEqual(expected, result)
9696

9797
def test_insert_node2(self):
@@ -101,11 +101,11 @@ def test_insert_node2(self):
101101
ops = g.get_nodes()
102102
g.topological_sort(ops)
103103
result = onnx_to_graphviz(g)
104-
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
104+
expected = 'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n7 [op_type=Abs] ' \
105105
'n3 [op_type=Abs] n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] ' \
106-
'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
106+
'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] ' \
107107
'input -> n1 n1:0 -> n7 n7:0 -> n3 n7:0 -> n2 n2:0 -> n4 n3:0 -> n4 n4:0 -> n5 ' \
108-
'n5_raw_output___2:0 -> n6 n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
108+
'n5_raw_output___3:0 -> n6 n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
109109
self.assertEqual(expected, result)
110110

111111
def test_remove_input(self):
@@ -116,11 +116,11 @@ def test_remove_input(self):
116116
ops = g.get_nodes()
117117
g.topological_sort(ops)
118118
result = onnx_to_graphviz(g)
119-
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] ' \
119+
expected = 'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] n3 [op_type=Abs] ' \
120120
'n2 [op_type=Abs] n4 [op_type=Add] n5 [op_type=Abs] n6 [op_type=Identity] ' \
121-
'n5_graph_outputs_Identity__3 [op_type=Identity] input -> n1 n1:0 -> n3 ' \
122-
'n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 n5_raw_output___2:0 -> n6 ' \
123-
'n5_raw_output___2:0 -> n5_graph_outputs_Identity__3 }'
121+
'n5_graph_outputs_Identity__4 [op_type=Identity] input -> n1 n1:0 -> n3 ' \
122+
'n1:0 -> n2 n2:0 -> n4 n4:0 -> n5 n5_raw_output___3:0 -> n6 ' \
123+
'n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
124124
self.assertEqual(expected, result)
125125

126126
def test_rewrite_subgraph(self):
@@ -144,11 +144,11 @@ def test_rewrite_subgraph(self):
144144
g.remove_node(n.name)
145145
g.topological_sort(ops)
146146
result = onnx_to_graphviz(g)
147-
expected = 'digraph { Placeholder__4 [op_type=Placeholder] n1 [op_type=Abs] ' \
148-
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__5 [op_type=Sub] ' \
149-
'n6 [op_type=Identity] n5_graph_outputs_Identity__3 [op_type=Identity] ' \
150-
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__5 n3:0 -> ReplacedOp__5 ' \
151-
'ReplacedOp__5:0 -> n6 ReplacedOp__5:0 -> n5_graph_outputs_Identity__3 }'
147+
expected = 'digraph { Placeholder__5 [op_type=Placeholder] n1 [op_type=Abs] ' \
148+
'n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__6 [op_type=Sub] ' \
149+
'n6 [op_type=Identity] n5_graph_outputs_Identity__4 [op_type=Identity] ' \
150+
'input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__6 n3:0 -> ReplacedOp__6 ' \
151+
'ReplacedOp__6:0 -> n6 ReplacedOp__6:0 -> n5_graph_outputs_Identity__4 }'
152152
self.assertEqual(expected, result)
153153

154154
def test_match_flipped(self):

tf2onnx/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
462462
self._output_to_consumers = {}
463463
self._input_to_graph = {}
464464
self.shapes = {}
465-
self.graph_name = graph_name or "tf2onnx"
465+
self.graph_name = graph_name or utils.make_name("tf2onnx")
466466
self._is_subgraph = is_subgraph
467467
self.ta_reads = []
468468
self.func_inputs = []
@@ -1006,7 +1006,7 @@ def _get_unvisited_child(g, node, not_visited):
10061006
all_input = list(filter(lambda a: a != '', all_input))
10071007
for inp in sorted(all_input):
10081008
j = self.get_node_by_output(inp)
1009-
utils.make_sure(j is not None, "Cannot find node with output %r", inp)
1009+
utils.make_sure(j is not None, "Cannot find node with output %r in graph %r", inp, self.graph_name)
10101010
if self.parent_graph and j.name not in op_name_to_index:
10111011
# there might be some outer-scoped inputs for an inner Graph.
10121012
pass

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def _calculate_new_shape(graph, op):
100100
nodes = self.nodes
101101
# if channel==1 or height==width==1, replace transpose with reshape
102102
# replacing trans with reshape is because transpose will copy data even if this transpose doesn't nothing
103+
need_sort = False
103104
for op in nodes:
104105
if op.type == "Transpose":
105106
input_shape = self._g.get_shape(op.input[0])
@@ -112,7 +113,9 @@ def _calculate_new_shape(graph, op):
112113
# replace transpose with reshape
113114
self._g.remove_node(op.name)
114115
self._g.make_node("Reshape", [op.input[0], new_shape], name=op.name, outputs=op.output)
115-
self._g.topological_sort(self._g.get_nodes())
116+
need_sort = True
117+
if need_sort:
118+
self._g.topological_sort(self._g.get_nodes())
116119

117120
def merge_duplicated_transposes(self):
118121
# strategy used in previous procedure is to move transpose nodes down if possible,
@@ -283,12 +286,12 @@ def _handle_nhwc_tranpose(self, trans):
283286
op_handler = self._handler_map[p.type]
284287
return op_handler(trans, p)
285288
return False
286-
# move transpose into branches to let Transposes can be "handled" in each branch
287-
for n in out_nodes:
288-
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
289-
n.graph.replace_input(n, trans.output[0], branch_trans.output[0])
290-
291-
self._g.remove_node(trans.name)
289+
if out_nodes:
290+
# move transpose into branches to let Transposes can be "handled" in each branch
291+
for n in out_nodes:
292+
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
293+
n.graph.replace_input(n, trans.output[0], branch_trans.output[0])
294+
self._g.remove_node(trans.name)
292295
return False
293296

294297
def _remove_useless_tranpose(self, trans):

0 commit comments

Comments
 (0)