Skip to content

Commit 8d5984b

Browse files
committed
Add replace_input.
1 parent c99dcb2 commit 8d5984b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ def _switch_transpose_and_node(self, node, trans):
251251

252252
ops = self._g.get_nodes()
253253
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
254-
node.input[input_index] = trans.input[0]
255-
trans.input[0] = node.output[0]
254+
self._g.replace_input(node, node.input[input_index], trans.input[0], input_index)
255+
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
256256

257257
# need to transpose node shape in backward direction as well after switch
258258
# otherwise, reshape added in post_optimize_action may not work correctly
@@ -409,7 +409,7 @@ def _add_handler(self, trans, node):
409409
conv_inputs = [t_p.input[0], t_p.input[1], node.input[1]]
410410
conv_node = self._g.make_node(t_p.type, conv_inputs, attr=t_p.attr_onnx)
411411
ops = self._g.get_nodes()
412-
trans.input[0] = utils.port_name(conv_node.name)
412+
self._g.replace_input(trans, trans.input[0], utils.port_name(conv_node.name), 0)
413413
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
414414
self._g.remove_node(t_p.name)
415415
self._g.remove_node(node.name)
@@ -456,7 +456,7 @@ def _mul_handler(self, trans, node):
456456
if not self._switch_transpose_and_node(node, trans):
457457
return False
458458

459-
node.input[input_index] = multiplier_input_node.input[0]
459+
self._g.replace_input(node, node.input[input_index], multiplier_input_node.input[0], input_index)
460460
self._g.remove_node(multiplier_input_node.name)
461461
return True
462462

@@ -527,8 +527,9 @@ def _sum_handler(self, trans, node):
527527
# switch to trans(sum(x1, x2, x3, ...))
528528
ops = self._g.get_nodes()
529529
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
530-
node.input = [n.output[0] if n.is_const() else n.input[0] for n in inputs]
531-
trans.input[0] = node.output[0]
530+
new_input = [n.output[0] if n.is_const() else n.input[0] for n in inputs]
531+
self._g.replace_inputs(node, new_input)
532+
self._g.replace_input(trans, trans.input[0], node.output[0], 0)
532533

533534
# adjust shape if present
534535
shape = self._g.get_shape(node.output[0])

0 commit comments

Comments
 (0)