Skip to content

Commit 33ee88f

Browse files
Fixed bug in Transpose optimizer mul handler for consts with multiple consumers (#1255)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent eecd3a4 commit 33ee88f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,10 +439,12 @@ def _maxmin_handler(self, trans, node):
439439
def _mul_handler(self, trans, node):
440440
multiplier_input_id = None
441441
multiplier_input_node = None
442-
for i, input_node in zip(node.input, node.inputs):
443-
if i != trans.output[0]:
444-
multiplier_input_id = i
442+
multiplier_input_idx = None
443+
for idx, (input_id, input_node) in enumerate(zip(node.input, node.inputs)):
444+
if input_id != trans.output[0]:
445+
multiplier_input_id = input_id
445446
multiplier_input_node = input_node
447+
multiplier_input_idx = idx
446448

447449
# node's inputs may come from one same node. if so the multiplier_input_node may be none
448450
if multiplier_input_node is None:
@@ -492,6 +494,11 @@ def _mul_handler(self, trans, node):
492494
# shape is (1)
493495
return self._switch_transpose_and_node(node, trans)
494496

497+
if not self._nodes_has_single_consumer_node([multiplier_input_node]):
498+
new_inp = self._g.copy_const(multiplier_input_node)
499+
self._g.replace_input(node, multiplier_input_id, new_inp.output[0], multiplier_input_idx)
500+
multiplier_input_node = new_inp
501+
495502
# shape is (N). reshape so that trans(shape) = 1,1,...,N
496503
perm = list(trans.get_attr('perm').ints)
497504
new_shape = np.ones(len(perm), dtype=np.int32)

0 commit comments

Comments
 (0)