@@ -552,15 +552,23 @@ def _slice_handler(self, trans, node):
552
552
axes = None
553
553
if self ._g .opset < 10 :
554
554
axes = node .get_attr ("axes" ).ints
555
+ if axes == [0 , 1 , 2 , 3 ]:
556
+ node .set_attr ("axes" , NCHW_TO_NHWC )
557
+ return self ._switch_transpose_and_node (node , trans )
555
558
else : # in opset 10, axes is input instead of an attribute.
556
- if len (node .inputs ) >= 4 :
557
- axes_node = node .inputs [3 ]
558
- if axes_node .is_const ():
559
- axes = axes_node .get_tensor_value (as_list = True )
560
-
561
- if axes == [0 , 1 , 2 , 3 ]:
562
- node .set_attr ("axes" , NCHW_TO_NHWC )
563
- return self ._switch_transpose_and_node (node , trans )
559
+ if len (node .inputs ) >= 4 and node .inputs [3 ].is_const ():
560
+ axes = node .inputs [3 ].get_tensor_value (as_list = True )
561
+ if axes == [0 , 1 , 2 , 3 ]:
562
+ # axes node might be shared
563
+ new_axes = np .array (NCHW_TO_NHWC , dtype = np .int64 )
564
+ if self ._nodes_has_single_consumer_node ([node ]):
565
+ node .inputs [3 ].set_tensor_value (new_axes )
566
+ else :
567
+ new_axes_const = self ._g .make_const (
568
+ utils .make_name (node .inputs [3 ].name ), new_axes
569
+ )
570
+ self ._g .replace_input (node , node .input [3 ], new_axes_const .output [0 ])
571
+ return self ._switch_transpose_and_node (node , trans )
564
572
return False
565
573
566
574
def _simple_through_handler (self , trans , node ):
0 commit comments