@@ -311,7 +311,27 @@ def _switch_transpose_and_node(self, node, trans, update_shape=True):
311
311
self ._g .set_shape (node .output [0 ], new_shape )
312
312
self ._g .set_shape (trans .output [0 ], shape )
313
313
return True
314
-
314
+ # this is for the case where node has multiple outputs. e.g. split node.
315
+ def _switch_transpose_and_node_with_multiple_outputs (self , node , trans , update_shape = True ):
316
+ input_index = self ._get_input_index_for_trans (node , trans )
317
+ for idx ,_output in enumerate (node .output ):
318
+ shape = self ._g .get_shape (_output )
319
+ nxt_nodes = self ._g .find_output_consumers (_output )
320
+ if idx == 0 :
321
+ transpose = trans
322
+ self ._g .replace_input (node , node .input [input_index ], transpose .input [0 ], input_index )
323
+ self ._g .replace_input (trans , trans .input [0 ], _output , 0 )
324
+ else :
325
+ transpose = self ._g .make_node ("Transpose" , [_output ], attr = {"perm" : trans .get_attr_value ("perm" )})
326
+ for nxt_node in nxt_nodes :
327
+ self ._g .replace_input (nxt_node , _output , transpose .output [0 ])
328
+
329
+ if update_shape and shape :
330
+ perm_inv = invert_perm (transpose .get_attr_value ("perm" ))
331
+ new_shape = [shape [i ] for i in perm_inv ]
332
+ self ._g .set_shape (_output , new_shape )
333
+ self ._g .set_shape (transpose .output [0 ], shape )
334
+ return True
315
335
# if return value is True, then it means Transpose is handled as designed
316
336
# otherwise, it means that we skip handling since it is not in our support set
317
337
def _handle_nhwc_tranpose (self , trans ):
@@ -694,6 +714,21 @@ def _split_handler(self, trans, node):
694
714
new_axes_const = self ._g .make_const (utils .make_name (node .inputs [1 ].name ), new_axes_np )
695
715
self ._g .replace_inputs (node , [node .input [0 ], new_axes_const .output [0 ]])
696
716
return True
717
+ # handling having branches
718
+ if len (node .output ) > 1 :
719
+ trans_rank = get_transpose_rank (trans )
720
+ axes = node .get_attr_value ("axis" , 0 )
721
+ perm = trans .get_attr ("perm" ).ints
722
+ axes = [axes + trans_rank if axes < 0 else axes ]
723
+ if split :
724
+ new_axes_np = np .array (split , dtype = np .int64 )
725
+ new_axes_const = self ._g .make_const (utils .make_name (node .inputs [1 ].name ), new_axes_np )
726
+ # [Transpose -> Split -> next_nodes] -> [Split -> Transpose -> next_nodes]
727
+ if not self ._switch_transpose_and_node_with_multiple_outputs (node , trans , 1 ):
728
+ return False
729
+ new_axes = [perm [a ] for a in axes ]
730
+ node .set_attr ("axes" , new_axes )
731
+ return True
697
732
return False
698
733
699
734
def _unsqueeze_handler (self , trans , node ):
0 commit comments