Skip to content

Commit 4c75cb8

Browse files
authored
Merge pull request #840 from jignparm/jignparm/optimize_transpose_multiply
Push transposes past mul ops to optimize mobilenet_v3
2 parents 14af496 + c539368 commit 4c75cb8

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

tests/test_optimizers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,28 @@ def test_trans_can_be_replaced_with_reshape2(self):
687687
self.run_transpose_compare(["Y"], {"X": np.random.randn(*input_shape_np).astype(np.float32)},
688688
model_proto, remaining_transpose_num=0)
689689

690+
def test_two_transposes_switch_with_mul(self):
    """Build Transpose(u1) * Transpose(u2) * const followed by an inverse Transpose.

    The graph is handed to run_transpose_compare with
    remaining_transpose_num=0.
    """
    nchw_shape = (1, 6, 8, 9)
    to_nhwc = [0, 2, 3, 1]
    to_nchw = [0, 3, 1, 2]

    # Scalar multiplier consumed by the second Mul.
    scale = self._make_onnx_const(np.array(10, dtype=np.float32), "const_10")

    # Both Mul operands come from identically-permuted Transposes.
    left_trans = helper.make_node("Transpose", ["u1"], ["v1"], perm=to_nhwc, name="trans_0")
    right_trans = helper.make_node("Transpose", ["u2"], ["v2"], perm=to_nhwc, name="trans_1")

    product = helper.make_node("Mul", ["v1", "v2"], ["x"], name="mul_1")
    scaled = helper.make_node("Mul", ["x", scale.output[0]], ["y"], name="mul_2")
    back_trans = helper.make_node("Transpose", ["y"], ["res"], perm=to_nchw, name="trans_3")

    graph_inputs = [
        helper.make_tensor_value_info("u1", TensorProto.FLOAT, nchw_shape),
        helper.make_tensor_value_info("u2", TensorProto.FLOAT, nchw_shape),
    ]
    graph_outputs = [helper.make_tensor_value_info("res", TensorProto.FLOAT, nchw_shape)]
    graph = helper.make_graph(
        [scale, left_trans, right_trans, product, scaled, back_trans],
        "test-transpose-mul",
        graph_inputs,
        graph_outputs,
    )

    model_proto = self.make_model(graph, producer_name="onnx-tests")

    # Feed values are drawn in the same order as before: u1 first, then u2.
    feed = {
        "u1": np.random.randn(*nchw_shape).astype(np.float32),
        "u2": np.random.randn(*nchw_shape).astype(np.float32),
    }
    self.run_transpose_compare(["res"], feed, model_proto, remaining_transpose_num=0)
711+
690712
# Transpose Optimizer Tests End
691713

692714
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,24 @@ def _mul_handler(self, trans, node):
433433
multiplier_input_node = input_node
434434

435435
# node's inputs may come from the same node; if so, multiplier_input_node may be None
436-
if multiplier_input_node is None or not multiplier_input_node.is_const():
436+
if multiplier_input_node is None:
437+
return False
438+
439+
# convert mul(trans(x), trans(y)) -> trans(mul(x, y))
440+
if multiplier_input_node.type == "Transpose":
441+
if is_nhwc_transpose(multiplier_input_node):
442+
if not self._nodes_has_single_consumer_node([multiplier_input_node]):
443+
return False
444+
input_index = self._get_input_index_for_trans(node, multiplier_input_node)
445+
if not self._switch_transpose_and_node(node, trans):
446+
return False
447+
448+
node.input[input_index] = multiplier_input_node.input[0]
449+
self._g.remove_node(multiplier_input_node.name)
450+
return True
451+
452+
# handle const multipliers
453+
if not multiplier_input_node.is_const():
437454
return False
438455
multiplier = multiplier_input_node.get_tensor_value(as_list=False)
439456

0 commit comments

Comments
 (0)