@@ -847,33 +847,32 @@ def version_10(cls, ctx, node, **kwargs):
847
847
if not np .all (new_begin_mask == 1 ):
848
848
if begin .is_const () and strides .is_const ():
849
849
new_begin_vals = np .copy (begin .get_tensor_value (as_list = False ))
850
- for i , v in enumerate (strides .get_tensor_value (as_list = False )):
851
- if v < 0 and new_begin_mask [i ] == 0 :
852
- new_begin_vals [i ] = max_size
850
+ strides_vals = strides .get_tensor_value (as_list = False )
851
+ idx1 = np .where (new_begin_mask == 0 )
852
+ idx2 = np .where (strides_vals < 0 )
853
+ idx3 = np .intersect1d (idx1 , idx2 )
854
+ new_begin_vals [idx3 ] = max_size
853
855
begin = ctx .make_const (utils .make_name ("begin_masked" ), new_begin_vals )
854
856
else :
855
857
begin_mask_const = ctx .make_const (utils .make_name ("begin_mask" ), np .equal (new_begin_mask , 0 ))
856
858
zero_const = ctx .make_const (utils .make_name ("zero_const" ), np .zeros (1 , dtype = np_dtype ))
857
859
max_const = ctx .make_const (utils .make_name ("max_const" ), np .array (max_size , dtype = np_dtype ))
858
- is_reverse_steps = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]],
859
- op_name_scope = node .name )
860
- is_reverse_and_full_range = ctx .make_node ("And" , [is_reverse_steps .output [0 ],
861
- begin_mask_const .output [0 ]], op_name_scope = node .name )
862
- begin = ctx .make_node ("Where" , [is_reverse_and_full_range .output [0 ], max_const .output [0 ],
863
- begin .output [0 ]], op_name_scope = node .name )
860
+ op1 = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]], op_name_scope = node .name )
861
+ op2 = ctx .make_node ("And" , [op1 .output [0 ], begin_mask_const .output [0 ]], op_name_scope = node .name )
862
+ begin = ctx .make_node ("Where" , [op2 .output [0 ], max_const .output [0 ], begin .output [0 ]],
863
+ op_name_scope = node .name )
864
864
865
865
# mask end
866
866
new_end_mask = np .array (new_end_mask , dtype = np_dtype )
867
867
end_output = end .output [0 ]
868
868
if not np .all (new_end_mask == min_size ):
869
- if end .is_const () and strides .is_const () and False :
869
+ if end .is_const () and strides .is_const ():
870
870
new_end_mask = np .maximum (end .get_tensor_value (as_list = False ), new_end_mask )
871
- for i , v in enumerate ( strides . get_tensor_value ( as_list = False )):
872
- if new_end_mask [ i ] == max_size :
873
- new_end_mask [i ] *= np . sign ( v )
871
+ idx = np . where ( new_end_mask == max_size )
872
+ sign = np . sign ( strides . get_tensor_value ( as_list = False ))[ idx ]
873
+ new_end_mask [ idx ] = new_end_mask [idx ] * sign
874
874
end = ctx .make_const (utils .make_name ("end_masked" ), new_end_mask )
875
875
end_output = end .output [0 ]
876
-
877
876
else :
878
877
# Overlay new_end_mask with specified end values.
879
878
# Adjust max_size to min_size if steps are < 0
@@ -884,13 +883,10 @@ def version_10(cls, ctx, node, **kwargs):
884
883
outputname = utils .make_name ("{}__newendmask" .format (node .name ))
885
884
new_end_mask = math .make_min_or_max_op (ctx , "Max" , [end .output [0 ], end_mask_const .output [0 ]],
886
885
[outputname ])
887
- is_reverse_steps = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]],
888
- op_name_scope = node .name )
889
- is_full_range = ctx .make_node ("Equal" , [new_end_mask .output [0 ], max_const .output [0 ]],
890
- op_name_scope = node .name )
891
- is_reverse_and_full_range = ctx .make_node ("And" , [is_full_range .output [0 ], is_reverse_steps .output [0 ]],
892
- op_name_scope = node .name )
893
- final_end = ctx .make_node ("Where" , [is_reverse_and_full_range .output [0 ], min_const .output [0 ],
886
+ op1 = ctx .make_node ("Less" , [strides .output [0 ], zero_const .output [0 ]], op_name_scope = node .name )
887
+ op2 = ctx .make_node ("Equal" , [new_end_mask .output [0 ], max_const .output [0 ]], op_name_scope = node .name )
888
+ op3 = ctx .make_node ("And" , [op2 .output [0 ], op1 .output [0 ]], op_name_scope = node .name )
889
+ final_end = ctx .make_node ("Where" , [op3 .output [0 ], min_const .output [0 ],
894
890
new_end_mask .output [0 ]], op_name_scope = node .name )
895
891
end_output = final_end .output [0 ]
896
892
0 commit comments