Skip to content

Commit e98f0d6

Browse files
committed
refactor indexing utilities
1 parent 3a7961f commit e98f0d6

File tree

1 file changed

+14
-26
lines changed

1 file changed

+14
-26
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -845,14 +845,16 @@ def _nonzero_impl(ary):
845845
return res
846846

847847

848-
def _validate_indices(inds, queue_list, usm_type_list):
848+
def _get_indices_queue_usm_type(inds, queue, usm_type):
849849
"""
850850
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851851
dtype or Python integers. At least one must be an array.
852852
853853
For each array, the queue and usm type are appended to `queue_list` and
854854
`usm_type_list`, respectively.
855855
"""
856+
queues = [queue]
857+
usm_types = [usm_type]
856858
any_array = False
857859
for ind in inds:
858860
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
@@ -863,8 +865,8 @@ def _validate_indices(inds, queue_list, usm_type_list):
863865
"type"
864866
)
865867
if isinstance(ind, dpt.usm_ndarray):
866-
queue_list.append(ind.sycl_queue)
867-
usm_type_list.append(ind.usm_type)
868+
queues.append(ind.sycl_queue)
869+
usm_types.append(ind.usm_type)
868870
elif not isinstance(ind, Integral):
869871
raise TypeError(
870872
"all elements of `ind` expected to be usm_ndarrays, "
@@ -874,7 +876,9 @@ def _validate_indices(inds, queue_list, usm_type_list):
874876
raise TypeError(
875877
"at least one element of `inds` expected to be an array"
876878
)
877-
return inds
879+
usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
880+
q = dpctl.utils.get_execution_queue(queues)
881+
return q, usm_type
878882

879883

880884
def _prepare_indices_arrays(inds, q, usm_type):
@@ -931,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
931935
raise ValueError(
932936
"Invalid value for mode keyword, only 0 or 1 is supported"
933937
)
934-
queues_ = [
935-
ary.sycl_queue,
936-
]
937-
usm_types_ = [
938-
ary.usm_type,
939-
]
940938
if not isinstance(inds, (list, tuple)):
941939
inds = (inds,)
942940

943-
_validate_indices(inds, queues_, usm_types_)
944-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
945-
exec_q = dpctl.utils.get_execution_queue(queues_)
941+
exec_q, res_usm_type = _get_indices_queue_usm_type(
942+
inds, ary.sycl_queue, ary.usm_type
943+
)
946944
if exec_q is None:
947945
raise dpctl.utils.ExecutionPlacementError(
948946
"Can not automatically determine where to allocate the "
@@ -1068,23 +1066,13 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10681066
raise ValueError(
10691067
"Invalid value for mode keyword, only 0 or 1 is supported"
10701068
)
1071-
if isinstance(vals, dpt.usm_ndarray):
1072-
queues_ = [ary.sycl_queue, vals.sycl_queue]
1073-
usm_types_ = [ary.usm_type, vals.usm_type]
1074-
else:
1075-
queues_ = [
1076-
ary.sycl_queue,
1077-
]
1078-
usm_types_ = [
1079-
ary.usm_type,
1080-
]
10811069
if not isinstance(inds, (list, tuple)):
10821070
inds = (inds,)
10831071

1084-
_validate_indices(inds, queues_, usm_types_)
1072+
exec_q, vals_usm_type = _get_indices_queue_usm_type(
1073+
inds, ary.sycl_queue, ary.usm_type
1074+
)
10851075

1086-
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
1087-
exec_q = dpctl.utils.get_execution_queue(queues_)
10881076
if exec_q is not None:
10891077
if not isinstance(vals, dpt.usm_ndarray):
10901078
vals = dpt.asarray(

0 commit comments

Comments
 (0)