@@ -193,7 +193,7 @@ def _handle_node_having_branches(self, node):
193
193
194
194
# make sure node's all input transpose all have only 1 consumer node,
195
195
# otherwise, it would impact their other output nodes
196
- if self ._transpose_has_single_consumer_node (node .inputs ):
196
+ if self ._nodes_has_single_consumer_node (node .inputs ):
197
197
self ._create_transpose_pairs_after_node (node )
198
198
input_transposes = node .inputs
199
199
for n in input_transposes :
@@ -229,7 +229,7 @@ def _get_input_index_for_trans(self, node, trans):
229
229
230
230
# the assumption is: both node and trans have only 1 output
231
231
def _switch_transpose_and_node (self , node , trans ):
232
- if not self ._transpose_has_single_consumer_node ([trans ]):
232
+ if not self ._nodes_has_single_consumer_node ([trans ]):
233
233
return False
234
234
235
235
input_index = self ._get_input_index_for_trans (node , trans )
@@ -277,13 +277,12 @@ def _remove_useless_tranpose(self, trans):
277
277
self ._g .replace_all_inputs (self ._g .get_nodes (), trans .output [0 ], trans .input [0 ])
278
278
self ._g .remove_node (trans .name )
279
279
280
- def _transpose_has_single_consumer_node (self , trans_nodes ):
281
- result = True
282
- for n in trans_nodes :
283
- cnt = len (set (self ._g .find_output_consumers (n .output [0 ])))
284
- result = result and cnt == 1
285
- if not result :
286
- return False
280
+ def _nodes_has_single_consumer_node (self , nodes ):
281
+ for n in nodes :
282
+ for output in n .output :
283
+ cnt = len (set (self ._g .find_output_consumers (output )))
284
+ if cnt != 1 :
285
+ return False
287
286
return True
288
287
289
288
def _get_non_nchw_transpose_output_nodes (self , node ):
@@ -475,7 +474,7 @@ def _simple_through_handler(self, trans, node):
475
474
476
475
def _shape_handler (self , trans , node ):
477
476
# input > trans > shape can be changed into input > shape > gather
478
- if not self ._transpose_has_single_consumer_node ([trans ]):
477
+ if not self ._nodes_has_single_consumer_node ([trans ]):
479
478
return False
480
479
481
480
output_shape = self ._g .get_shape (node .output [0 ])
0 commit comments