@@ -439,10 +439,12 @@ def _maxmin_handler(self, trans, node):
439
439
def _mul_handler (self , trans , node ):
440
440
multiplier_input_id = None
441
441
multiplier_input_node = None
442
- for i , input_node in zip (node .input , node .inputs ):
443
- if i != trans .output [0 ]:
444
- multiplier_input_id = i
442
+ multiplier_input_idx = None
443
+ for idx , (input_id , input_node ) in enumerate (zip (node .input , node .inputs )):
444
+ if input_id != trans .output [0 ]:
445
+ multiplier_input_id = input_id
445
446
multiplier_input_node = input_node
447
+ multiplier_input_idx = idx
446
448
447
449
# node's inputs may come from one same node. if so the multiplier_input_node may be none
448
450
if multiplier_input_node is None :
@@ -492,6 +494,11 @@ def _mul_handler(self, trans, node):
492
494
# shape is (1)
493
495
return self ._switch_transpose_and_node (node , trans )
494
496
497
+ if not self ._nodes_has_single_consumer_node ([multiplier_input_node ]):
498
+ new_inp = self ._g .copy_const (multiplier_input_node )
499
+ self ._g .replace_input (node , multiplier_input_id , new_inp .output [0 ], multiplier_input_idx )
500
+ multiplier_input_node = new_inp
501
+
495
502
# shape is (N). reshape so that trans(shape) = 1,1,...,N
496
503
perm = list (trans .get_attr ('perm' ).ints )
497
504
new_shape = np .ones (len (perm ), dtype = np .int32 )
0 commit comments