@@ -845,37 +845,51 @@ def version_10(cls, ctx, node, **kwargs):
845
845
# mask begin
846
846
new_begin_mask = np .array (new_begin_mask , dtype = np_dtype )
847
847
if not np .all (new_begin_mask == 1 ):
848
- if begin .is_const ():
849
- begin = ctx .make_const (
850
- utils .make_name ("begin_masked" ),
851
- begin .get_tensor_value (as_list = False ) * new_begin_mask
852
- )
848
+ if begin .is_const () and strides .is_const ():
849
+ new_begin_vals = np .copy (begin .get_tensor_value (as_list = False ))
850
+ strides_vals = strides .get_tensor_value (as_list = False )
851
+ idx1 = np .where (new_begin_mask == 0 )
852
+ idx2 = np .where (strides_vals < 0 )
853
+ idx3 = np .intersect1d (idx1 , idx2 )
854
+ new_begin_vals [idx3 ] = max_size
855
+ begin = ctx .make_const (utils .make_name ("begin_masked" ), new_begin_vals )
853
856
else :
854
- begin_mask_const = ctx .make_const (
855
- utils .make_name ("begin_mask " ),
856
- new_begin_mask
857
- )
858
- begin = ctx .make_node (
859
- "Mul " , [begin .output [0 ], begin_mask_const .output [0 ]],
860
- op_name_scope = node .name
861
- )
857
+ begin_mask_const = ctx .make_const (utils . make_name ( "begin_mask" ), np . equal ( new_begin_mask , 0 ))
858
+ zero_const = ctx . make_const ( utils .make_name ("zero_const " ), np . zeros ( 1 , dtype = np_dtype ))
859
+ max_const = ctx . make_const ( utils . make_name ( "max_const" ), np . array ( max_size , dtype = np_dtype ))
860
+ op1 = ctx . make_node ( "Less" , [ strides . output [ 0 ], zero_const . output [ 0 ]], op_name_scope = node . name )
861
+ op2 = ctx .make_node ("And" , [ op1 . output [ 0 ], begin_mask_const . output [ 0 ]], op_name_scope = node . name )
862
+ begin = ctx . make_node ( "Where " , [op2 .output [0 ], max_const . output [ 0 ], begin .output [0 ]],
863
+ op_name_scope = node .name )
864
+
862
865
# mask end
863
866
new_end_mask = np .array (new_end_mask , dtype = np_dtype )
864
867
end_output = end .output [0 ]
865
868
if not np .all (new_end_mask == min_size ):
866
- if end .is_const ():
867
- end = ctx .make_const (
868
- utils .make_name ("end_masked" ),
869
- np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
870
- )
869
+ if end .is_const () and strides .is_const ():
870
+ new_end_mask = np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
871
+ idx = np .where (new_end_mask == max_size )
872
+ sign = np .sign (strides .get_tensor_value (as_list = False ))[idx ]
873
+ new_end_mask [idx ] = new_end_mask [idx ] * sign
874
+ end = ctx .make_const (utils .make_name ("end_masked" ), new_end_mask )
871
875
end_output = end .output [0 ]
872
876
else :
873
- end_mask_const = ctx .make_const (
874
- utils .make_name ("end_mask" ),
875
- np .array (new_end_mask , dtype = np_dtype )
876
- )
877
- end_output = utils .make_name ("{}__end" .format (node .name ))
878
- math .make_min_or_max_op (ctx , "Max" , [end .output [0 ], end_mask_const .output [0 ]], [end_output ])
877
+ # Overlay new_end_mask with specified end values.
878
+ # Where the overlaid end equals max_size and the stride is negative, replace it with min_size.
879
+ max_const = ctx .make_const (utils .make_name ("max_const" ), np .array (max_size , dtype = np_dtype ))
880
+ min_const = ctx .make_const (utils .make_name ("min_const" ), np .array (min_size , dtype = np_dtype ))
881
+ zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
882
+ end_mask_const = ctx .make_const (utils .make_name ("end_mask" ), np .array (new_end_mask , dtype = np_dtype ))
883
+ outputname = utils .make_name ("{}__newendmask" .format (node .name ))
884
+ new_end_mask = math .make_min_or_max_op (ctx , "Max" , [end .output [0 ], end_mask_const .output [0 ]],
885
+ [outputname ])
886
+ op1 = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]], op_name_scope = node .name )
887
+ op2 = ctx .make_node ("Equal" , [new_end_mask .output [0 ], max_const .output [0 ]], op_name_scope = node .name )
888
+ op3 = ctx .make_node ("And" , [op2 .output [0 ], op1 .output [0 ]], op_name_scope = node .name )
889
+ final_end = ctx .make_node ("Where" , [op3 .output [0 ], min_const .output [0 ],
890
+ new_end_mask .output [0 ]], op_name_scope = node .name )
891
+ end_output = final_end .output [0 ]
892
+
879
893
# mask strides for shrink
880
894
shrink_strided_mask = np .array (shrink_strided_mask , dtype = np_dtype )
881
895
strides_output = strides .output [0 ]
0 commit comments