Skip to content

Commit 5617517

Browse files
committed
fix bug in transpose optimizer
1 parent 2ce73d4 commit 5617517

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,13 @@ def _transpose_handler(self, trans, node):
340340
ops = self._g.get_nodes()
341341
self._g.replace_all_inputs(ops, node.output[0], trans.input[0])
342342

343+
shape = self._g.get_shape(node.output[0])
344+
dtype = self._g.get_dtype(node.output[0])
343345
self._g.remove_node(trans.name)
344346
self._g.remove_node(node.name)
347+
if node.output[0] in self._g.outputs:
348+
self._g.make_node("Identity", [trans.input[0]],
349+
outputs=node.output, shapes=[shape], dtypes=[dtype])
345350
return True
346351
return False
347352

@@ -408,6 +413,8 @@ def _mul_handler(self, trans, node):
408413
return False
409414

410415
def _identity_handler(self, trans, node):
416+
if node.output[0] in self._g.outputs:
417+
return False
411418
ops = self._g.get_nodes()
412419
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
413420
self._g.remove_node(node.name)

0 commit comments

Comments
 (0)