@@ -845,14 +845,16 @@ def _nonzero_impl(ary):
845
845
return res
846
846
847
847
848
- def _validate_indices (inds , queue_list , usm_type_list ):
848
+ def _get_indices_queue_usm_type (inds , queue , usm_type ):
849
849
"""
850
850
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851
851
dtype or Python integers. At least one must be an array.
852
852
853
853
For each array, the queue and usm type are appended to `queue_list` and
854
854
`usm_type_list`, respectively.
855
855
"""
856
+ queues = [queue ]
857
+ usm_types = [usm_type ]
856
858
any_array = False
857
859
for ind in inds :
858
860
if isinstance (ind , (np .ndarray , dpt .usm_ndarray )):
@@ -863,8 +865,8 @@ def _validate_indices(inds, queue_list, usm_type_list):
863
865
"type"
864
866
)
865
867
if isinstance (ind , dpt .usm_ndarray ):
866
- queue_list .append (ind .sycl_queue )
867
- usm_type_list .append (ind .usm_type )
868
+ queues .append (ind .sycl_queue )
869
+ usm_types .append (ind .usm_type )
868
870
elif not isinstance (ind , Integral ):
869
871
raise TypeError (
870
872
"all elements of `ind` expected to be usm_ndarrays, "
@@ -874,7 +876,9 @@ def _validate_indices(inds, queue_list, usm_type_list):
874
876
raise TypeError (
875
877
"at least one element of `inds` expected to be an array"
876
878
)
877
- return inds
879
+ usm_type = dpctl .utils .get_coerced_usm_type (usm_types )
880
+ q = dpctl .utils .get_execution_queue (queues )
881
+ return q , usm_type
878
882
879
883
880
884
def _prepare_indices_arrays (inds , q , usm_type ):
@@ -931,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
931
935
raise ValueError (
932
936
"Invalid value for mode keyword, only 0 or 1 is supported"
933
937
)
934
- queues_ = [
935
- ary .sycl_queue ,
936
- ]
937
- usm_types_ = [
938
- ary .usm_type ,
939
- ]
940
938
if not isinstance (inds , (list , tuple )):
941
939
inds = (inds ,)
942
940
943
- _validate_indices ( inds , queues_ , usm_types_ )
944
- res_usm_type = dpctl . utils . get_coerced_usm_type ( usm_types_ )
945
- exec_q = dpctl . utils . get_execution_queue ( queues_ )
941
+ exec_q , res_usm_type = _get_indices_queue_usm_type (
942
+ inds , ary . sycl_queue , ary . usm_type
943
+ )
946
944
if exec_q is None :
947
945
raise dpctl .utils .ExecutionPlacementError (
948
946
"Can not automatically determine where to allocate the "
@@ -1068,23 +1066,13 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1068
1066
raise ValueError (
1069
1067
"Invalid value for mode keyword, only 0 or 1 is supported"
1070
1068
)
1071
- if isinstance (vals , dpt .usm_ndarray ):
1072
- queues_ = [ary .sycl_queue , vals .sycl_queue ]
1073
- usm_types_ = [ary .usm_type , vals .usm_type ]
1074
- else :
1075
- queues_ = [
1076
- ary .sycl_queue ,
1077
- ]
1078
- usm_types_ = [
1079
- ary .usm_type ,
1080
- ]
1081
1069
if not isinstance (inds , (list , tuple )):
1082
1070
inds = (inds ,)
1083
1071
1084
- _validate_indices (inds , queues_ , usm_types_ )
1072
+ exec_q , vals_usm_type = _get_indices_queue_usm_type (
1073
+ inds , ary .sycl_queue , ary .usm_type
1074
+ )
1085
1075
1086
- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
1087
- exec_q = dpctl .utils .get_execution_queue (queues_ )
1088
1076
if exec_q is not None :
1089
1077
if not isinstance (vals , dpt .usm_ndarray ):
1090
1078
vals = dpt .asarray (
0 commit comments