@@ -131,6 +131,9 @@ def merge_duplicated_transposes(self):
131
131
graph .delete_unused_nodes (graph .outputs )
132
132
133
133
def _optimize (self , graph ):
134
+ return self ._apply_optimization (graph , self ._optimize_at_current_graph_level )
135
+
136
+ def _optimize_at_current_graph_level (self , graph ):
134
137
self ._g = graph
135
138
self .pre_optimize_action ()
136
139
no_action = False
@@ -190,7 +193,7 @@ def _handle_node_having_branches(self, node):
190
193
191
194
# make sure node's all input transpose all have only 1 consumer node,
192
195
# otherwise, it would impact their other output nodes
193
- if self ._transpose_has_single_consumer_node (node .inputs ):
196
+ if self ._nodes_has_single_consumer_node (node .inputs ):
194
197
self ._create_transpose_pairs_after_node (node )
195
198
input_transposes = node .inputs
196
199
for n in input_transposes :
@@ -226,7 +229,7 @@ def _get_input_index_for_trans(self, node, trans):
226
229
227
230
# the assumption is: both node and trans have only 1 output
228
231
def _switch_transpose_and_node (self , node , trans ):
229
- if not self ._transpose_has_single_consumer_node ([trans ]):
232
+ if not self ._nodes_has_single_consumer_node ([trans ]):
230
233
return False
231
234
232
235
input_index = self ._get_input_index_for_trans (node , trans )
@@ -274,13 +277,12 @@ def _remove_useless_tranpose(self, trans):
274
277
self ._g .replace_all_inputs (self ._g .get_nodes (), trans .output [0 ], trans .input [0 ])
275
278
self ._g .remove_node (trans .name )
276
279
277
- def _transpose_has_single_consumer_node (self , trans_nodes ):
278
- result = True
279
- for n in trans_nodes :
280
- cnt = len (set (self ._g .find_output_consumers (n .output [0 ])))
281
- result = result and cnt == 1
282
- if not result :
283
- return False
280
+ def _nodes_has_single_consumer_node (self , nodes ):
281
+ for n in nodes :
282
+ for output in n .output :
283
+ cnt = len (set (self ._g .find_output_consumers (output )))
284
+ if cnt != 1 :
285
+ return False
284
286
return True
285
287
286
288
def _get_non_nchw_transpose_output_nodes (self , node ):
@@ -472,7 +474,7 @@ def _simple_through_handler(self, trans, node):
472
474
473
475
def _shape_handler (self , trans , node ):
474
476
# input > trans > shape can be changed into input > shape > gather
475
- if not self ._transpose_has_single_consumer_node ([trans ]):
477
+ if not self ._nodes_has_single_consumer_node ([trans ]):
476
478
return False
477
479
478
480
output_shape = self ._g .get_shape (node .output [0 ])
0 commit comments