Skip to content

Commit c02f793

Browse files
Don't copy/split transpose unless further optimization is likely (#1679)
* Don't copy/split tranpose unless further optimization is likely Signed-off-by: Tom Wildenhain <[email protected]> * bump ci Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8c8e585 commit c02f793

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/test_optimizers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,6 @@ def test_two_transposes_switch_with_mul(self, shape, perm_input, perm_output):
14361436
model_proto, remaining_transpose_num=0)
14371437

14381438
@parameterized.expand([
1439-
((1, 6, 8), (8, 1, 6), [2, 0, 1], [1, 2, 0]),
14401439
((1, 6, 8, 9), (1, 8, 9, 6), [0, 2, 3, 1], [0, 3, 1, 2]),
14411440
((1, 6, 8, 9, 2), (1, 8, 9, 2, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
14421441
])

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,9 @@ def _handle_nhwc_tranpose(self, trans):
323323
op_handler = self._handler_map[p.type]
324324
return op_handler(trans, p)
325325
return False
326-
if out_nodes:
327-
# move transpose into branches to let Transposes can be "handled" in each branch
326+
if out_nodes and trans.get_attr_value("perm") in [NCHW_TO_NHWC, NCDHW_TO_NDHWC]:
327+
# Move transpose into branches to let Transposes can be "handled" in each branch.
328+
# This will add more transpose ops, so only do this if further optimization is likely (check perm).
328329
for n in out_nodes:
329330
branch_trans = n.graph.make_node("Transpose", [trans.input[0]], attr=trans.get_onnx_attrs())
330331
n.graph.replace_input(n, trans.output[0], branch_trans.output[0])

0 commit comments

Comments
 (0)