Skip to content

Commit d03b003

Browse files
committed
fix bug
1 parent eae4272 commit d03b003

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,17 @@ def pre_optimize_action(self):
6060
target_t = reshape_op.inputs[0].get_tensor_value(as_list=False)
6161
target_shape = reshape_op.inputs[1].get_tensor_value(as_list=False)
6262
new_data = np.reshape(target_t, tuple(target_shape))
63-
const_name = utils.port_name(utils.make_name("Const"))
63+
const_name = reshape_op.output[0]
64+
self._g.remove_node(reshape_op.name)
65+
self._g.make_const(const_name, new_data)
6466

6567
# point all children nodes inputs to the new node
6668
for output_name in reshape_op.output:
6769
for child in ops:
6870
for i, name in enumerate(child.input):
6971
if name == output_name:
7072
child.input[i] = const_name
71-
self._g.make_const(const_name, new_data)
72-
self._g.remove_node(reshape_op.name)
73+
7374
self._g.topological_sort(self._g.get_nodes())
7475

7576
def post_optimize_action(self):

0 commit comments

Comments
 (0)