Skip to content

Commit 07dfaa8

Browse files
authored
Add a condiction when we try to remove a transpose node. (#2272)
Signed-off-by: Jay Zhang <[email protected]>
1 parent ae4c39e commit 07dfaa8

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,10 +509,14 @@ def _add_handler(self, trans, node):
509509
return True
510510
return self._handle_node_having_branches(trans, node)
511511

512+
def _output_node_has_single_consumer_node(self, node):
513+
output_node = self._g.get_node_by_name(node.output[0])
514+
return output_node and output_node.output and self._nodes_has_single_consumer_node([output_node])
515+
512516
def _transpose_handler(self, trans, node):
513517
perm = trans.get_attr_value("perm")
514518
perm_inv = invert_perm(perm)
515-
if is_tranpose_of_type(node, perm_inv):
519+
if is_tranpose_of_type(node, perm_inv) and self._output_node_has_single_consumer_node(node):
516520
for g in {self._g, node.graph}:
517521
g.replace_all_inputs(node.output[0], trans.input[0]) # ops=g.get_nodes()
518522

tf2onnx/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33

44
version = '1.15.1'
5-
git_version = 'dc6155b52a137d858456fcc6bc720c327eec5612'
5+
git_version = 'ae4c39ed3bdab7edf487d73d5892a573684d1d6a'

0 commit comments

Comments
 (0)