Skip to content

Commit 1ee0bed

Browse files
committed
generalize function "_transpose_has_single_consumer_node"
1 parent 0a87482 commit 1ee0bed

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def _handle_node_having_branches(self, node):
193193

194194
# make sure node's all input transpose all have only 1 consumer node,
195195
# otherwise, it would impact their other output nodes
196-
if self._transpose_has_single_consumer_node(node.inputs):
196+
if self._nodes_has_single_consumer_node(node.inputs):
197197
self._create_transpose_pairs_after_node(node)
198198
input_transposes = node.inputs
199199
for n in input_transposes:
@@ -229,7 +229,7 @@ def _get_input_index_for_trans(self, node, trans):
229229

230230
# the assumption is: both node and trans have only 1 output
231231
def _switch_transpose_and_node(self, node, trans):
232-
if not self._transpose_has_single_consumer_node([trans]):
232+
if not self._nodes_has_single_consumer_node([trans]):
233233
return False
234234

235235
input_index = self._get_input_index_for_trans(node, trans)
@@ -277,13 +277,12 @@ def _remove_useless_tranpose(self, trans):
277277
self._g.replace_all_inputs(self._g.get_nodes(), trans.output[0], trans.input[0])
278278
self._g.remove_node(trans.name)
279279

280-
def _transpose_has_single_consumer_node(self, trans_nodes):
281-
result = True
282-
for n in trans_nodes:
283-
cnt = len(set(self._g.find_output_consumers(n.output[0])))
284-
result = result and cnt == 1
285-
if not result:
286-
return False
280+
def _nodes_has_single_consumer_node(self, nodes):
281+
for n in nodes:
282+
for output in n.output:
283+
cnt = len(set(self._g.find_output_consumers(output)))
284+
if cnt != 1:
285+
return False
287286
return True
288287

289288
def _get_non_nchw_transpose_output_nodes(self, node):
@@ -475,7 +474,7 @@ def _simple_through_handler(self, trans, node):
475474

476475
def _shape_handler(self, trans, node):
477476
# input > trans > shape can be changed into input > shape > gather
478-
if not self._transpose_has_single_consumer_node([trans]):
477+
if not self._nodes_has_single_consumer_node([trans]):
479478
return False
480479

481480
output_shape = self._g.get_shape(node.output[0])

0 commit comments

Comments
 (0)