Skip to content

Commit 8edc4c9

Browse files
fix transpose bug
1 parent 343affe commit 8edc4c9

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

tests/test_optimizers.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tf2onnx import utils
1313
from tf2onnx.graph import GraphUtil
1414
from backend_test_base import Tf2OnnxBackendTestBase
15-
from common import unittest_main
15+
from common import unittest_main, group_nodes_by_type
1616

1717

1818
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -162,6 +162,33 @@ def test_transpose_with_identity(self):
162162
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
163163
model_proto, remaining_transpose_num=1)
164164

165+
def test_trans_output_as_graph_outputs(self):
166+
"""
167+
If transpose's output is graph's output, don't optimize it.
168+
"""
169+
trans = helper.make_node("Transpose", ["X"], ["Y"], name="trans", perm=[0, 2, 3, 1])
170+
graph_proto = helper.make_graph(
171+
[trans],
172+
"trans-to-graph-output",
173+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
174+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 5, 3))],
175+
)
176+
177+
graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)
178+
# remove identity to graph output
179+
identity_op = graph.get_node_by_output(graph.outputs[0])
180+
graph.outputs = [identity_op.input[0]]
181+
graph.remove_node(identity_op.name)
182+
183+
optimized_graph = GraphUtil.optimize_graph(graph, "onnx-tests")
184+
185+
self.assertTrue(optimized_graph, msg="graph after optimizer should not be None")
186+
187+
trans_cnt = len(group_nodes_by_type(optimized_graph)["Transpose"])
188+
189+
self.assertTrue(trans_cnt == 1, msg="Expect 1 Transpose ops left, but actually " +
190+
str(trans_cnt) + " left")
191+
165192
# Tranpose Optimizer Tests End
166193

167194
# Identity Optimizer Tests Start
@@ -396,6 +423,5 @@ def test_duplicated_need_multiple_run(self):
396423
op_type="Log", remaining_op_num=3)
397424
# Merge Duplicated Nodes Optimizer Tests End
398425

399-
400426
if __name__ == "__main__":
401427
unittest_main()

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ def _switch_transpose_and_node(self, node, trans):
256256
# if return value is True, then it means Transpose is handled as designed
257257
# otherwise, it means that we skip handling since it is not in our support set
258258
def _handle_nhwc_tranpose(self, trans):
259+
if trans.output[0] in self._g.outputs:
260+
log.debug("%s connects to graph outputs, skip", trans.output[0])
261+
return False
259262
out_nodes = self._g.find_output_consumers(trans.output[0])
260263
if len(out_nodes) == 1:
261264
p = out_nodes[0]

0 commit comments

Comments
 (0)