Skip to content

Commit 4224f40

Browse files
Handle Mul as Square in TransposeOptimizer (#1339)
Signed-off-by: Mateusz Tabaka <[email protected]> Co-authored-by: TomWildenhain-Microsoft <[email protected]>
1 parent e8842ba commit 4224f40

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

tests/test_optimizers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,28 @@ def test_transpose_merge(self, input_shape1, input_shape2, perm):
477477
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*input_shape1).astype(np.float32)},
478478
model_proto, remaining_transpose_num=1)
479479

480+
481+
@parameterized.expand([
482+
((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),
483+
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
484+
])
485+
def test_transpose_mul_as_square(self, shape, perm_input, perm_output):
486+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans")
487+
node1 = helper.make_node("Mul", ["Y", "Y"], ["Z"], name="mul")
488+
node2 = helper.make_node("Transpose", ["Z"], ["OUT"], perm=perm_output, name="trans_1")
489+
490+
graph = helper.make_graph(
491+
[node0, node1, node2],
492+
"transpose-mul-as-sqr-test",
493+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
494+
[helper.make_tensor_value_info("OUT", TensorProto.FLOAT, shape)],
495+
)
496+
497+
model_proto = self.make_model(graph, producer_name="onnx-tests")
498+
self.run_transpose_compare(["OUT"], {"X": np.random.randn(*shape).astype(np.float32)},
499+
model_proto, remaining_transpose_num=0)
500+
501+
480502
@parameterized.expand([
481503
((2, 3, 4, 5), [0, 2, 3, 1]),
482504
((2, 3, 4, 5, 6), [0, 2, 3, 4, 1]),

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,13 @@ def _mul_handler(self, trans, node):
460460

461461
# node's inputs may come from one same node. if so the multiplier_input_node may be none
462462
if multiplier_input_node is None:
463-
return False
463+
if not self._nodes_has_single_consumer_node([trans]):
464+
return False
465+
self._g.replace_all_inputs(node.output[0], trans.output[0])
466+
self._g.replace_input(node, node.input[0], trans.input[0], 0)
467+
self._g.replace_input(node, node.input[1], trans.input[0], 1)
468+
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
469+
return True
464470

465471
# convert mul(trans(x), trans(y)) -> trans(mul(x, y))
466472
if multiplier_input_node.type == "Transpose":

0 commit comments

Comments
 (0)