@@ -845,37 +845,57 @@ def version_10(cls, ctx, node, **kwargs):
845
845
# mask begin
846
846
new_begin_mask = np .array (new_begin_mask , dtype = np_dtype )
847
847
if not np .all (new_begin_mask == 1 ):
848
- if begin .is_const ():
849
- begin = ctx .make_const (
850
- utils .make_name ("begin_masked" ),
851
- begin .get_tensor_value (as_list = False ) * new_begin_mask
852
- )
848
+ if begin .is_const () and strides .is_const ():
849
+ begin_vals = begin .get_tensor_value (as_list = False )
850
+ strides_vals = strides .get_tensor_value (as_list = False )
851
+ new_begin_vals = np .copy (begin_vals )
852
+ for i , v in enumerate (strides_vals ):
853
+ if v < 0 and new_begin_mask [i ] == 0 :
854
+ new_begin_vals [i ] = max_size
855
+ begin = ctx .make_const (utils .make_name ("begin_masked" ), new_begin_vals )
853
856
else :
854
- begin_mask_const = ctx .make_const (
855
- utils .make_name ("begin_mask" ),
856
- new_begin_mask
857
- )
858
- begin = ctx .make_node (
859
- "Mul" , [begin .output [0 ], begin_mask_const .output [0 ]],
860
- op_name_scope = node .name
861
- )
857
+ begin_mask_const = ctx .make_const (utils .make_name ("begin_mask" ), np .equal (new_begin_mask , 0 ))
858
+ zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
859
+ max_const = ctx .make_const (utils .make_name ("max_const" ), np .array (max_size , dtype = np_dtype ))
860
+ is_reverse_steps = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]],
861
+ op_name_scope = node .name )
862
+ is_reverse_and_full_range = ctx .make_node ("And" , [is_reverse_steps .output [0 ],
863
+ begin_mask_const .output [0 ]], op_name_scope = node .name )
864
+ begin = ctx .make_node ("Where" , [is_reverse_and_full_range .output [0 ], max_const .output [0 ],
865
+ begin .output [0 ]], op_name_scope = node .name )
866
+
862
867
# mask end
863
868
new_end_mask = np .array (new_end_mask , dtype = np_dtype )
864
869
end_output = end .output [0 ]
865
870
if not np .all (new_end_mask == min_size ):
866
- if end .is_const ():
867
- end = ctx .make_const (
868
- utils .make_name ("end_masked" ),
869
- np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
870
- )
871
+ if end .is_const () and strides .is_const () and False :
872
+ new_end_mask = np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
873
+ for i , v in enumerate (strides .get_tensor_value (as_list = False )):
874
+ if new_end_mask [i ] == max_size :
875
+ new_end_mask [i ] *= np .sign (v )
876
+ end = ctx .make_const (utils .make_name ("end_masked" ), new_end_mask )
871
877
end_output = end .output [0 ]
878
+
872
879
else :
873
- end_mask_const = ctx .make_const (
874
- utils .make_name ("end_mask" ),
875
- np .array (new_end_mask , dtype = np_dtype )
876
- )
877
- end_output = utils .make_name ("{}__end" .format (node .name ))
878
- math .make_min_or_max_op (ctx , "Max" , [end .output [0 ], end_mask_const .output [0 ]], [end_output ])
880
+ # Overlay new_end_mask with specified end values.
881
+ # Adjust max_size to min_size if steps are < 0
882
+ max_const = ctx .make_const (utils .make_name ("max_const" ), np .array (max_size , dtype = np_dtype ))
883
+ min_const = ctx .make_const (utils .make_name ("min_const" ), np .array (min_size , dtype = np_dtype ))
884
+ zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
885
+ end_mask_const = ctx .make_const (utils .make_name ("end_mask" ), np .array (new_end_mask , dtype = np_dtype ))
886
+ outputname = utils .make_name ("{}__newendmask" .format (node .name ))
887
+ new_end_mask = math .make_min_or_max_op (ctx , "Max" , [end .output [0 ], end_mask_const .output [0 ]],
888
+ [outputname ])
889
+ is_reverse_steps = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]],
890
+ op_name_scope = node .name )
891
+ is_full_range = ctx .make_node ("Equal" , [new_end_mask .output [0 ], max_const .output [0 ]],
892
+ op_name_scope = node .name )
893
+ is_reverse_and_full_range = ctx .make_node ("And" , [is_full_range .output [0 ], is_reverse_steps .output [0 ]],
894
+ op_name_scope = node .name )
895
+ final_end = ctx .make_node ("Where" , [is_reverse_and_full_range .output [0 ], min_const .output [0 ],
896
+ new_end_mask .output [0 ]], op_name_scope = node .name )
897
+ end_output = final_end .output [0 ]
898
+
879
899
# mask strides for shrink
880
900
shrink_strided_mask = np .array (shrink_strided_mask , dtype = np_dtype )
881
901
strides_output = strides .output [0 ]
0 commit comments