@@ -896,15 +896,15 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
896
896
# @int shrink_axis_mask, @int new_axis_mask)
897
897
# T output = Slice(T input, Tind starts, Tind ends, Tind axes, Tind steps)
898
898
# "ends" are exclusive, "axes" and "steps" are optional, their default val are [0, ...] and 1
899
- input_x = node .inputs [0 ]
900
- begin = node .inputs [1 ]
901
- end = node .inputs [2 ]
902
- strides = node .inputs [3 ]
899
+ input_x = node .input [0 ]
900
+ begin = node .input [1 ]
901
+ end = node .input [2 ]
902
+ strides = node .input [3 ]
903
903
new_axis_mask = node .get_attr ("new_axis_mask" )
904
904
new_axis_mask = new_axis_mask .i if new_axis_mask is not None else 0
905
905
906
- if begin .is_const () and end .is_const () and strides .is_const () \
907
- and all (val == 1 for val in strides .get_tensor_value ()) \
906
+ if ctx .is_const (begin ) and ctx .is_const (end ) and ctx .is_const (strides ) \
907
+ and all (val == 1 for val in ctx .get_tensor_value (strides )) \
908
908
and new_axis_mask == 0 :
909
909
cls .version_1 (ctx , node , ** kwargs )
910
910
return
@@ -945,7 +945,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
945
945
if (new_axis_mask >> bit ) & 1 == 1 :
946
946
num_new += 1
947
947
if (ellipsis_mask >> bit ) & 1 :
948
- input_shape = ctx .get_shape (input_x . output [ 0 ] )
948
+ input_shape = ctx .get_shape (input_x )
949
949
# calculate what rank for ellipsis: input rank - (being rank - all new_axis - 1)
950
950
ellipsis_gap = len (input_shape ) - param_rank + num_new + 1
951
951
if (new_axis_mask >> bit ) & 1 == 1 :
@@ -954,7 +954,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
954
954
end_mask |= 1 << bit
955
955
956
956
input_x = GraphBuilder (ctx ).make_unsqueeze (
957
- {'data' : input_x . output [ 0 ] , 'axes' : unqueeze_at }, return_node = True )
957
+ {'data' : input_x , 'axes' : unqueeze_at })
958
958
959
959
960
960
# use in onnx graph to mask begin
@@ -969,7 +969,7 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
969
969
ellipsis_gap = 0
970
970
for idx in range (param_rank ):
971
971
if (ellipsis_mask >> idx ) & 1 :
972
- input_shape = ctx .get_shape (input_x . output [ 0 ] )
972
+ input_shape = ctx .get_shape (input_x )
973
973
utils .make_sure (
974
974
input_shape is not None ,
975
975
"StridedSlice op {} requires the shape of input" .format (node .name )
@@ -1006,34 +1006,32 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
1006
1006
# mask begin
1007
1007
new_begin_mask = np .array (new_begin_mask , dtype = np_dtype )
1008
1008
if not np .all (new_begin_mask == 1 ):
1009
- if begin .is_const () and strides .is_const ():
1010
- new_begin_vals = np .copy (begin .get_tensor_value (as_list = False ))
1011
- strides_vals = strides .get_tensor_value (as_list = False )
1009
+ if ctx .is_const (begin ) and ctx .is_const (strides ):
1010
+ new_begin_vals = np .copy (ctx .get_tensor_value (begin , as_list = False ))
1011
+ strides_vals = ctx .get_tensor_value (strides , as_list = False )
1012
1012
idx1 = np .where (new_begin_mask == 0 )
1013
1013
idx2 = np .where (strides_vals < 0 )
1014
1014
idx3 = np .intersect1d (idx1 , idx2 )
1015
1015
new_begin_vals [idx3 ] = max_size
1016
- begin = ctx .make_const (utils .make_name ("begin_masked" ), new_begin_vals )
1016
+ begin = ctx .make_const (utils .make_name ("begin_masked" ), new_begin_vals ). output [ 0 ]
1017
1017
else :
1018
1018
begin_mask_const = ctx .make_const (utils .make_name ("begin_mask" ), np .equal (new_begin_mask , 0 ))
1019
1019
zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
1020
1020
max_const = ctx .make_const (utils .make_name ("max_const" ), np .array (max_size , dtype = np_dtype ))
1021
- op1 = ctx .make_node ("Less" , [strides . output [ 0 ] , zero_const .output [0 ]], op_name_scope = node .name )
1021
+ op1 = ctx .make_node ("Less" , [strides , zero_const .output [0 ]], op_name_scope = node .name )
1022
1022
op2 = ctx .make_node ("And" , [op1 .output [0 ], begin_mask_const .output [0 ]], op_name_scope = node .name )
1023
- begin = ctx .make_node ("Where" , [op2 .output [0 ], max_const .output [0 ], begin . output [ 0 ] ],
1024
- op_name_scope = node .name )
1023
+ begin = ctx .make_node ("Where" , [op2 .output [0 ], max_const .output [0 ], begin ],
1024
+ op_name_scope = node .name ). output [ 0 ]
1025
1025
1026
1026
# mask end
1027
1027
new_end_mask = np .array (new_end_mask , dtype = np_dtype )
1028
- end_output = end .output [0 ]
1029
1028
if not np .all (new_end_mask == min_size ):
1030
- if end .is_const () and strides .is_const ():
1031
- new_end_mask = np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
1029
+ if ctx .is_const (end ) and ctx .is_const (strides ):
1030
+ new_end_mask = np .maximum (ctx .get_tensor_value (end , as_list = False ), new_end_mask )
1032
1031
idx = np .where (new_end_mask == max_size )
1033
- sign = np .sign (strides .get_tensor_value (as_list = False ))[idx ]
1032
+ sign = np .sign (ctx .get_tensor_value (strides , as_list = False ))[idx ]
1034
1033
new_end_mask [idx ] = new_end_mask [idx ] * sign
1035
- end = ctx .make_const (utils .make_name ("end_masked" ), new_end_mask )
1036
- end_output = end .output [0 ]
1034
+ end = ctx .make_const (utils .make_name ("end_masked" ), new_end_mask ).output [0 ]
1037
1035
else :
1038
1036
# Overlay new_end_mask with specified end values.
1039
1037
# Adjust max_size to min_size if steps are < 0
@@ -1042,25 +1040,22 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
1042
1040
zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
1043
1041
end_mask_const = ctx .make_const (utils .make_name ("end_mask" ), np .array (new_end_mask , dtype = np_dtype ))
1044
1042
outputname = utils .make_name ("{}__newendmask" .format (node .name ))
1045
- new_end_mask = math .make_min_or_max_op (ctx , "Max" , [end . output [ 0 ] , end_mask_const .output [0 ]],
1043
+ new_end_mask = math .make_min_or_max_op (ctx , "Max" , [end , end_mask_const .output [0 ]],
1046
1044
[outputname ])
1047
- op1 = ctx .make_node ("Less" , [strides . output [ 0 ] , zero_const .output [0 ]], op_name_scope = node .name )
1045
+ op1 = ctx .make_node ("Less" , [strides , zero_const .output [0 ]], op_name_scope = node .name )
1048
1046
op2 = ctx .make_node ("Equal" , [new_end_mask .output [0 ], max_const .output [0 ]], op_name_scope = node .name )
1049
1047
op3 = ctx .make_node ("And" , [op2 .output [0 ], op1 .output [0 ]], op_name_scope = node .name )
1050
- final_end = ctx .make_node ("Where" , [op3 .output [0 ], min_const .output [0 ],
1051
- new_end_mask .output [0 ]], op_name_scope = node .name )
1052
- end_output = final_end .output [0 ]
1048
+ end = ctx .make_node ("Where" , [op3 .output [0 ], min_const .output [0 ], new_end_mask .output [0 ]],
1049
+ op_name_scope = node .name ).output [0 ]
1053
1050
1054
1051
# mask strides for shrink
1055
1052
shrink_strided_mask = np .array (shrink_strided_mask , dtype = np_dtype )
1056
- strides_output = strides .output [0 ]
1057
1053
if not np .all (shrink_strided_mask == min_size ):
1058
- if strides .is_const ():
1054
+ if ctx .is_const (strides ):
1059
1055
strides = ctx .make_const (
1060
1056
utils .make_name ("strides_masked" ),
1061
- np .maximum (strides .get_tensor_value (as_list = False ), shrink_strided_mask )
1062
- )
1063
- strides_output = strides .output [0 ]
1057
+ np .maximum (ctx .get_tensor_value (strides , as_list = False ), shrink_strided_mask )
1058
+ ).output [0 ]
1064
1059
else :
1065
1060
shrink_strided_mask_const = ctx .make_const (
1066
1061
utils .make_name ("strides_mask" ),
@@ -1069,9 +1064,10 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
1069
1064
strides_output = utils .make_name ("{}__strides" .format (node .name ))
1070
1065
math .make_min_or_max_op (
1071
1066
ctx , "Max" ,
1072
- [strides . output [ 0 ] , shrink_strided_mask_const .output [0 ]],
1067
+ [strides , shrink_strided_mask_const .output [0 ]],
1073
1068
[strides_output ]
1074
1069
)
1070
+ strides = strides_output
1075
1071
# create axes input
1076
1072
axes_const = ctx .make_const (
1077
1073
utils .make_name ("slice_axes" ),
@@ -1080,10 +1076,10 @@ def any_version_after10(cls, opset, ctx, node, **kwargs):
1080
1076
axes_output = axes_const .output [0 ]
1081
1077
1082
1078
inputs_map = {
1083
- "data" : input_x . output [ 0 ] ,
1084
- "starts" : begin . output [ 0 ] ,
1085
- "ends" : end_output ,
1086
- "steps" : strides_output ,
1079
+ "data" : input_x ,
1080
+ "starts" : begin ,
1081
+ "ends" : end ,
1082
+ "steps" : strides ,
1087
1083
"axes" : axes_output
1088
1084
}
1089
1085
kwargs = {** inputs_map , "outputs" : node .output }
0 commit comments