Skip to content

Commit f407c31

Browse files
authored
Merge pull request #978 from jignparm/jignparm/transpose_mul
Enhance Transpose optimizer "Mul" handler
2 parents f5e4ed3 + e332323 commit f407c31

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/test_optimizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def test_trans_can_be_replaced_with_reshape2(self):
691691
model_proto, remaining_transpose_num=0)
692692

693693
def test_two_transposes_switch_with_mul(self):
694-
const_node = self._make_onnx_const(np.array(10, dtype=np.float32), "const_10")
694+
const_node = self._make_onnx_const(np.array(np.random.random(6), dtype=np.float32), "const_10")
695695
node0 = helper.make_node("Transpose", ["u1"], ["v1"], perm=[0, 2, 3, 1], name="trans_0")
696696
node1 = helper.make_node("Transpose", ["u2"], ["v2"], perm=[0, 2, 3, 1], name="trans_1")
697697

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,21 @@ def _mul_handler(self, trans, node):
482482
self._g.remove_node(node.name)
483483
return True
484484

485-
# if the shape is () or (1), we just move transpose after the mul
486-
if not multiplier.shape or (len(multiplier.shape) == 1 and multiplier.shape[0] == 1):
485+
# if the shape is (), we just move transpose after the mul
486+
if not multiplier.shape:
487+
return self._switch_transpose_and_node(node, trans)
488+
489+
# if multiplier is 1-D
490+
if len(multiplier.shape) == 1:
491+
if multiplier.shape[0] == 1:
492+
# shape is (1)
493+
return self._switch_transpose_and_node(node, trans)
494+
495+
# shape is (N). reshape so that trans(shape) = 1,1,...,N
496+
perm = list(trans.get_attr('perm').ints)
497+
new_shape = np.ones(len(perm), dtype=np.int32)
498+
new_shape[perm[-1]] = multiplier.shape[0]
499+
multiplier_input_node.set_tensor_value(multiplier.reshape(new_shape))
487500
return self._switch_transpose_and_node(node, trans)
488501

489502
return False

0 commit comments

Comments
 (0)