@@ -738,19 +738,19 @@ def version_10(cls, ctx, node, **kwargs):
738
738
# @int shrink_axis_mask, @int new_axis_mask)
739
739
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
740
740
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
741
+ input_x = node .inputs [0 ]
741
742
begin = node .inputs [1 ]
742
743
end = node .inputs [2 ]
743
744
strides = node .inputs [3 ]
745
+ new_axis_mask = node .get_attr ("new_axis_mask" )
746
+ new_axis_mask = new_axis_mask .i if new_axis_mask is not None else 0
747
+
744
748
if begin .is_const () and end .is_const () and strides .is_const () \
745
- and all (val == 1 for val in strides .get_tensor_value ()):
749
+ and all (val == 1 for val in strides .get_tensor_value ()) \
750
+ and new_axis_mask == 0 :
746
751
cls .version_1 (ctx , node , ** kwargs )
747
752
return
748
753
749
- not_supported_attr = ["new_axis_mask" ]
750
- for attr_name in not_supported_attr :
751
- attr = node .get_attr (attr_name )
752
- if attr is not None and attr .i != 0 :
753
- raise ValueError ("StridedSlice: attribute " + attr_name + " not supported" )
754
754
onnx_dtype = ctx .get_dtype (node .input [1 ])
755
755
np_dtype = utils .ONNX_TO_NUMPY_DTYPE [onnx_dtype ]
756
756
@@ -769,6 +769,15 @@ def version_10(cls, ctx, node, **kwargs):
769
769
ellipsis_mask = ellipsis_mask .i if ellipsis_mask is not None else 0
770
770
shrink_axis_mask = node .get_attr ("shrink_axis_mask" )
771
771
shrink_axis_mask = shrink_axis_mask .i if shrink_axis_mask is not None else 0
772
+ if new_axis_mask != 0 :
773
+ unqueeze_at = []
774
+ for bit in range (32 ):
775
+ if (new_axis_mask >> bit ) & 1 == 1 :
776
+ unqueeze_at .append (bit )
777
+ begin_mask |= 1 << bit
778
+ end_mask |= 1 << bit
779
+ input_x = ctx .make_node ("Unsqueeze" , [input_x .output [0 ]], {"axes" : unqueeze_at })
780
+
772
781
param_shape = ctx .get_shape (node .input [1 ]) or \
773
782
ctx .get_shape (node .input [2 ]) or \
774
783
ctx .get_shape (node .input [3 ])
@@ -789,7 +798,7 @@ def version_10(cls, ctx, node, **kwargs):
789
798
ellipsis_gap = 0
790
799
for idx in range (param_rank ):
791
800
if (ellipsis_mask >> idx ) & 1 :
792
- input_shape = ctx .get_shape (node . input [0 ])
801
+ input_shape = ctx .get_shape (input_x . output [0 ])
793
802
utils .make_sure (
794
803
input_shape is not None ,
795
804
"StridedSlice op {} requires the shape of input" .format (node .name )
@@ -886,7 +895,7 @@ def version_10(cls, ctx, node, **kwargs):
886
895
axes_output = axes_const .output [0 ]
887
896
888
897
inputs_map = {
889
- "data" : node . input [0 ],
898
+ "data" : input_x . output [0 ],
890
899
"starts" : begin .output [0 ],
891
900
"ends" : end_output ,
892
901
"steps" : strides_output ,
0 commit comments