@@ -251,8 +251,8 @@ def _switch_transpose_and_node(self, node, trans):
251
251
252
252
ops = self ._g .get_nodes ()
253
253
self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
254
- node .input [input_index ] = trans .input [0 ]
255
- trans .input [0 ] = node .output [0 ]
254
+ self . _g . replace_input ( node , node .input [input_index ], trans .input [0 ], input_index )
255
+ self . _g . replace_input ( trans , trans .input [0 ], node .output [0 ], 0 )
256
256
257
257
# need to transpose node shape in backward direction as well after switch
258
258
# otherwise, reshape added in post_optimize_action may not work correctly
@@ -409,7 +409,7 @@ def _add_handler(self, trans, node):
409
409
conv_inputs = [t_p .input [0 ], t_p .input [1 ], node .input [1 ]]
410
410
conv_node = self ._g .make_node (t_p .type , conv_inputs , attr = t_p .attr_onnx )
411
411
ops = self ._g .get_nodes ()
412
- trans .input [0 ] = utils .port_name (conv_node .name )
412
+ self . _g . replace_input ( trans , trans .input [0 ], utils .port_name (conv_node .name ), 0 )
413
413
self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
414
414
self ._g .remove_node (t_p .name )
415
415
self ._g .remove_node (node .name )
@@ -456,7 +456,7 @@ def _mul_handler(self, trans, node):
456
456
if not self ._switch_transpose_and_node (node , trans ):
457
457
return False
458
458
459
- node .input [input_index ] = multiplier_input_node .input [0 ]
459
+ self . _g . replace_input ( node , node .input [input_index ], multiplier_input_node .input [0 ], input_index )
460
460
self ._g .remove_node (multiplier_input_node .name )
461
461
return True
462
462
@@ -527,8 +527,9 @@ def _sum_handler(self, trans, node):
527
527
# switch to trans(sum(x1, x2, x3, ...))
528
528
ops = self ._g .get_nodes ()
529
529
self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
530
- node .input = [n .output [0 ] if n .is_const () else n .input [0 ] for n in inputs ]
531
- trans .input [0 ] = node .output [0 ]
530
+ new_input = [n .output [0 ] if n .is_const () else n .input [0 ] for n in inputs ]
531
+ self ._g .replace_inputs (node , new_input )
532
+ self ._g .replace_input (trans , trans .input [0 ], node .output [0 ], 0 )
532
533
533
534
# adjust shape if present
534
535
shape = self ._g .get_shape (node .output [0 ])
0 commit comments