Skip to content

Commit 3a7961f

Browse files
committed
permit NumPy arrays in indexing
1 parent aa05645 commit 3a7961f

File tree

2 files changed

+62
-42
lines changed

2 files changed

+62
-42
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756756
raise TypeError(
757757
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
758758
)
759-
if not isinstance(ary_mask, dpt.usm_ndarray):
760-
raise TypeError(
761-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
759+
if isinstance(ary_mask, dpt.usm_ndarray):
760+
dst_usm_type = dpctl.utils.get_coerced_usm_type(
761+
(ary.usm_type, ary_mask.usm_type)
762762
)
763-
dst_usm_type = dpctl.utils.get_coerced_usm_type(
764-
(ary.usm_type, ary_mask.usm_type)
765-
)
766-
exec_q = dpctl.utils.get_execution_queue(
767-
(ary.sycl_queue, ary_mask.sycl_queue)
768-
)
769-
if exec_q is None:
770-
raise dpctl.utils.ExecutionPlacementError(
771-
"arrays have different associated queues. "
772-
"Use `y.to_device(x.device)` to migrate."
763+
exec_q = dpctl.utils.get_execution_queue(
764+
(ary.sycl_queue, ary_mask.sycl_queue)
765+
)
766+
if exec_q is None:
767+
raise dpctl.utils.ExecutionPlacementError(
768+
"arrays have different associated queues. "
769+
"Use `y.to_device(x.device)` to migrate."
770+
)
771+
elif isinstance(ary_mask, np.ndarray):
772+
dst_usm_type = ary.usm_type
773+
exec_q = ary.sycl_queue
774+
ary_mask = dpt.asarray(
775+
ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q
776+
)
777+
else:
778+
raise TypeError(
779+
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780+
f"{type(ary_mask)}"
773781
)
774782
ary_nd = ary.ndim
775783
pp = normalize_axis_index(operator.index(axis), ary_nd)
@@ -839,31 +847,32 @@ def _nonzero_impl(ary):
839847

840848
def _validate_indices(inds, queue_list, usm_type_list):
841849
"""
842-
Utility for validating indices are usm_ndarray of integral dtype or Python
843-
integers. At least one must be an array.
850+
Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851+
dtype or Python integers. At least one must be an array.
844852
845853
For each array, the queue and usm type are appended to `queue_list` and
846854
`usm_type_list`, respectively.
847855
"""
848-
any_usmarray = False
856+
any_array = False
849857
for ind in inds:
850-
if isinstance(ind, dpt.usm_ndarray):
851-
any_usmarray = True
858+
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
859+
any_array = True
852860
if ind.dtype.kind not in "ui":
853861
raise IndexError(
854862
"arrays used as indices must be of integer (or boolean) "
855863
"type"
856864
)
857-
queue_list.append(ind.sycl_queue)
858-
usm_type_list.append(ind.usm_type)
865+
if isinstance(ind, dpt.usm_ndarray):
866+
queue_list.append(ind.sycl_queue)
867+
usm_type_list.append(ind.usm_type)
859868
elif not isinstance(ind, Integral):
860869
raise TypeError(
861-
"all elements of `ind` expected to be usm_ndarrays "
862-
f"or integers, found {type(ind)}"
870+
"all elements of `ind` expected to be usm_ndarrays, "
871+
f"NumPy arrays, or integers, found {type(ind)}"
863872
)
864-
if not any_usmarray:
873+
if not any_array:
865874
raise TypeError(
866-
"at least one element of `inds` expected to be a usm_ndarray"
875+
"at least one element of `inds` expected to be an array"
867876
)
868877
return inds
869878

@@ -942,8 +951,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942951
"be associated with the same queue."
943952
)
944953

945-
if len(inds) > 1:
946-
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
954+
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
947955

948956
ind0 = inds[0]
949957
ary_sh = ary.shape
@@ -976,16 +984,28 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976984
raise TypeError(
977985
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
978986
)
979-
if not isinstance(ary_mask, dpt.usm_ndarray):
980-
raise TypeError(
981-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
987+
if isinstance(ary_mask, dpt.usm_ndarray):
988+
exec_q = dpctl.utils.get_execution_queue(
989+
(
990+
ary.sycl_queue,
991+
ary_mask.sycl_queue,
992+
)
982993
)
983-
exec_q = dpctl.utils.get_execution_queue(
984-
(
985-
ary.sycl_queue,
986-
ary_mask.sycl_queue,
994+
if exec_q is None:
995+
raise dpctl.utils.ExecutionPlacementError(
996+
"arrays have different associated queues. "
997+
"Use `y.to_device(x.device)` to migrate."
998+
)
999+
elif isinstance(ary_mask, np.ndarray):
1000+
exec_q = ary.sycl_queue
1001+
ary_mask = dpt.asarray(
1002+
ary_mask, usm_type=ary.usm_type, sycl_queue=exec_q
1003+
)
1004+
else:
1005+
raise TypeError(
1006+
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1007+
f"{type(ary_mask)}"
9871008
)
988-
)
9891009
if exec_q is not None:
9901010
if not isinstance(vals, dpt.usm_ndarray):
9911011
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
@@ -1080,8 +1100,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10801100
"be associated with the same queue."
10811101
)
10821102

1083-
if len(inds) > 1:
1084-
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
1103+
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
10851104

10861105
ind0 = inds[0]
10871106
ary_sh = ary.shape

dpctl/tensor/_slicing.pxi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numbers
1818
from operator import index
1919
from cpython.buffer cimport PyObject_CheckBuffer
20+
from numpy import ndarray
2021

2122

2223
cdef bint _is_buffer(object o):
@@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len(
4647

4748
cdef bint _is_integral(object x) except *:
4849
"""Gives True if x is an integral slice spec"""
49-
if isinstance(x, usm_ndarray):
50+
if isinstance(x, (ndarray, usm_ndarray)):
5051
if x.ndim > 0:
5152
return False
5253
if x.dtype.kind not in "ui":
@@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *:
7475

7576
cdef bint _is_boolean(object x) except *:
7677
"""Gives True if x is an integral slice spec"""
77-
if isinstance(x, usm_ndarray):
78+
if isinstance(x, (ndarray, usm_ndarray)):
7879
if x.ndim > 0:
7980
return False
8081
if x.dtype.kind not in "b":
@@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
185186
raise IndexError(
186187
"Index {0} is out of range for axes 0 with "
187188
"size {1}".format(ind, shape[0]))
188-
elif isinstance(ind, usm_ndarray):
189+
elif isinstance(ind, (ndarray, usm_ndarray)):
189190
return (shape, strides, offset, (ind,), 0)
190191
elif isinstance(ind, tuple):
191192
axes_referenced = 0
@@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
216217
axes_referenced += 1
217218
if not array_streak_started and array_streak_interrupted:
218219
explicit_index += 1
219-
elif isinstance(i, usm_ndarray):
220+
elif isinstance(i, (ndarray, usm_ndarray)):
220221
if not seen_arrays_yet:
221222
seen_arrays_yet = True
222223
array_streak_started = True
@@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
302303
array_streak = False
303304
elif _is_integral(ind_i):
304305
if array_streak:
305-
if not isinstance(ind_i, usm_ndarray):
306+
if not isinstance(ind_i, (ndarray, usm_ndarray)):
306307
ind_i = index(ind_i)
307308
# integer will be converted to an array,
308309
# still raise if OOB
@@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
337338
"Index {0} is out of range for axes "
338339
"{1} with size {2}".format(ind_i, k, shape[k])
339340
)
340-
elif isinstance(ind_i, usm_ndarray):
341+
elif isinstance(ind_i, (ndarray, usm_ndarray)):
341342
if not array_streak:
342343
array_streak = True
343344
if not advanced_start_pos_set:

0 commit comments

Comments
 (0)