Skip to content

Commit efcd9cb

Browse files
Deployed lazy implementation of advanced indexing to develop tests
1 parent 483a423 commit efcd9cb

File tree

2 files changed

+230
-29
lines changed

2 files changed

+230
-29
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
# limitations under the License.
1616
import numpy as np
1717

18+
import dpctl
1819
import dpctl.memory as dpm
1920
import dpctl.tensor as dpt
2021
import dpctl.tensor._tensor_impl as ti
22+
import dpctl.utils
2123
from dpctl.tensor._device import normalize_queue_device
2224

2325
__doc__ = (
@@ -382,3 +384,124 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
382384
)
383385
_copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
384386
return R
387+
388+
389+
def _mock_extract(ary, ary_mask, p):
    """Select elements of ``ary`` with a boolean mask by round-tripping
    through NumPy on the host.

    Args:
        ary: array to extract from; must expose ``sycl_queue``,
            ``usm_type`` and ``dtype``.
        ary_mask: mask array; must expose ``sycl_queue`` and ``usm_type``.
        p: number of leading axes of ``ary`` taken in full
            (``slice(None)``) before the mask is applied.

    Returns:
        A newly allocated array holding the selected elements, with the
        dtype of ``ary``, the USM type coerced from both inputs, and
        associated with the common execution queue.

    Raises:
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for ``ary`` and ``ary_mask``.
    """
    exec_q = dpctl.utils.get_execution_queue(
        (ary.sycl_queue, ary_mask.sycl_queue)
    )
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    res_usm_type = dpctl.utils.get_coerced_usm_type(
        (ary.usm_type, ary_mask.usm_type)
    )

    # Copy both operands to the host and let NumPy do the masked selection.
    host_ary = dpt.asnumpy(ary)
    host_mask = dpt.asnumpy(ary_mask)
    leading_axes = (slice(None),) * p
    host_selected = host_ary[leading_axes + (host_mask,)]

    # Copy the host result into a freshly allocated USM array.
    out = dpt.empty(
        host_selected.shape,
        dtype=ary.dtype,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
    )
    out[...] = host_selected
    return out
418+
419+
420+
def _mock_nonzero(ary):
    """Compute indices of non-zero elements of ``ary`` via NumPy on the host.

    Args:
        ary: a ``dpctl.tensor.usm_ndarray``.

    Returns:
        A tuple of arrays, one per array returned by ``numpy.nonzero`` on
        the host copy of ``ary``. Each is created with ``dpt.asarray``
        using the USM type and queue of ``ary``.

    Raises:
        TypeError: if ``ary`` is not a ``usm_ndarray``.
    """
    if not isinstance(ary, dpt.usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(ary)}"
        )
    q = ary.sycl_queue
    usm_type = ary.usm_type
    # The computation happens on the host; only the results go back to USM.
    nz = dpt.asnumpy(ary).nonzero()
    return tuple(dpt.asarray(i, usm_type=usm_type, sycl_queue=q) for i in nz)
428+
429+
430+
def _mock_take_multi_index(ary, inds, p):
    """Gather elements of ``ary`` using integer index arrays by
    round-tripping through NumPy on the host.

    Args:
        ary: array to take from; must expose ``sycl_queue``, ``usm_type``
            and ``dtype``.
        inds: iterable of index arrays; each must expose ``sycl_queue``,
            ``usm_type`` and ``dtype``.
        p: number of leading axes of ``ary`` taken in full
            (``slice(None)``) before the index arrays are applied.

    Returns:
        A newly allocated array holding the gathered elements, with the
        dtype of ``ary``, the USM type coerced from all inputs, and
        associated with the common execution queue.

    Raises:
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for ``ary`` and the index arrays.
        IndexError: if any index array has a dtype kind other than signed
            or unsigned integer.
    """
    # Materialize once so that a generator argument can be traversed twice.
    inds = tuple(inds)
    queues_ = [ary.sycl_queue]
    usm_types_ = [ary.usm_type]
    for ind in inds:
        queues_.append(ind.sycl_queue)
        usm_types_.append(ind.usm_type)

    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )
    if not all(ind.dtype.kind in "ui" for ind in inds):
        raise IndexError(
            "arrays used as indices must be of integer (or boolean) type"
        )

    # Perform the gather on host copies of the operands.
    ary_np = dpt.asnumpy(ary)
    ind_np = (slice(None),) * p + tuple(dpt.asnumpy(ind) for ind in inds)
    res_np = ary_np[ind_np]

    res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
    res = dpt.empty(
        res_np.shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
    )
    res[...] = res_np
    return res
459+
460+
461+
def _mock_place(ary, ary_mask, p, vals):
    """Assign ``vals`` into the elements of ``ary`` selected by a boolean
    mask, by round-tripping through NumPy on the host.

    Args:
        ary: destination array; modified in place via ``ary[...] = ...``.
        ary_mask: mask array; must expose ``sycl_queue``.
        p: number of leading axes of ``ary`` taken in full
            (``slice(None)``) before the mask is applied.
        vals: values to assign. May be a ``usm_ndarray``, an object
            exposing ``__sycl_usm_array_interface__``, or anything
            ``numpy.asarray`` accepts (e.g. a Python scalar or a NumPy
            array).

    Returns:
        None.

    Raises:
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for the USM operands.
    """
    # Objects that only advertise the USM array interface are wrapped first,
    # so that they take part in queue validation like any usm_ndarray.
    if not isinstance(vals, dpt.usm_ndarray) and hasattr(
        vals, "__sycl_usm_array_interface__"
    ):
        vals = dpt.asarray(vals)
    vals_is_usm = isinstance(vals, dpt.usm_ndarray)

    queues_ = [ary.sycl_queue, ary_mask.sycl_queue]
    if vals_is_usm:
        # Host-side values (scalars, NumPy arrays) carry no queue.
        queues_.append(vals.sycl_queue)
    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )

    ary_np = dpt.asnumpy(ary)
    mask_np = dpt.asnumpy(ary_mask)
    vals_np = dpt.asnumpy(vals) if vals_is_usm else np.asarray(vals)
    # Perform the masked assignment on the host copy, then write it back.
    ary_np[(slice(None),) * p + (mask_np,)] = vals_np
    ary[...] = ary_np
    return
479+
480+
481+
def _mock_put_multi_index(ary, inds, p, vals):
    """Assign ``vals`` into the elements of ``ary`` addressed by integer
    index arrays, by round-tripping through NumPy on the host.

    Args:
        ary: destination array; modified in place via ``ary[...] = ...``.
        inds: iterable of index arrays; each must expose ``sycl_queue``
            and ``dtype``.
        p: number of leading axes of ``ary`` taken in full
            (``slice(None)``) before the index arrays are applied.
        vals: values to assign. May be a ``usm_ndarray``, an object
            exposing ``__sycl_usm_array_interface__``, or anything
            ``numpy.asarray`` accepts (e.g. a Python scalar or a NumPy
            array).

    Returns:
        None.

    Raises:
        dpctl.utils.ExecutionPlacementError: if no common execution queue
            can be determined for the USM operands.
        IndexError: if any index array has a dtype kind other than signed
            or unsigned integer.
    """
    # Objects that only advertise the USM array interface are wrapped first,
    # so that they take part in queue validation like any usm_ndarray.
    if not isinstance(vals, dpt.usm_ndarray) and hasattr(
        vals, "__sycl_usm_array_interface__"
    ):
        vals = dpt.asarray(vals)
    vals_is_usm = isinstance(vals, dpt.usm_ndarray)

    # Materialize once so that a generator argument can be traversed twice.
    inds = tuple(inds)
    queues_ = [ary.sycl_queue]
    if vals_is_usm:
        # Host-side values (scalars, NumPy arrays) carry no queue.
        queues_.append(vals.sycl_queue)
    queues_.extend(ind.sycl_queue for ind in inds)

    exec_q = dpctl.utils.get_execution_queue(queues_)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Can not automatically determine where to allocate the "
            "result or performance execution. "
            "Use `usm_ndarray.to_device` method to migrate data to "
            "be associated with the same queue."
        )
    if not all(ind.dtype.kind in "ui" for ind in inds):
        raise IndexError(
            "arrays used as indices must be of integer (or boolean) type"
        )

    ary_np = dpt.asnumpy(ary)
    vals_np = dpt.asnumpy(vals) if vals_is_usm else np.asarray(vals)
    ind_np = (slice(None),) * p + tuple(dpt.asnumpy(ind) for ind in inds)
    # Perform the indexed assignment on the host copy, then write it back.
    ary_np[ind_np] = vals_np
    ary[...] = ary_np
    return

dpctl/tensor/_usmarray.pyx

Lines changed: 107 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import numpy as np
2525
import dpctl
2626
import dpctl.memory as dpmem
2727

28+
from ._data_types import bool as dpt_bool
2829
from ._device import Device
2930
from ._print import usm_ndarray_repr, usm_ndarray_str
3031

@@ -34,6 +35,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
3435
cimport dpctl as c_dpctl
3536
cimport dpctl.memory as c_dpmem
3637
cimport dpctl.tensor._dlpack as c_dlpack
38+
3739
import dpctl.tensor._flags as _flags
3840

3941
include "_stride_utils.pxi"
@@ -648,6 +650,9 @@ cdef class usm_ndarray:
648650
self.get_offset())
649651
cdef usm_ndarray res
650652

653+
if len(_meta) < 5:
654+
raise RuntimeError
655+
651656
res = usm_ndarray.__new__(
652657
usm_ndarray,
653658
_meta[0],
@@ -658,7 +663,32 @@ cdef class usm_ndarray:
658663
)
659664
res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
660665
res.array_namespace_ = self.array_namespace_
661-
return res
666+
667+
adv_ind = _meta[3]
668+
adv_ind_start_p = _meta[4]
669+
670+
if adv_ind_start_p < 0:
671+
return res
672+
673+
from ._copy_utils import (
674+
_mock_extract,
675+
_mock_nonzero,
676+
_mock_take_multi_index,
677+
)
678+
if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
679+
return _mock_extract(res, adv_ind[0], adv_ind_start_p)
680+
681+
if any(ind.dtype == dpt_bool for ind in adv_ind):
682+
adv_ind_int = list()
683+
for ind in adv_ind:
684+
if ind.dtype == dpt_bool:
685+
adv_ind_int.extend(_mock_nonzero(ind))
686+
else:
687+
adv_ind_int.append(ind)
688+
return _mock_take_multi_index(res, tuple(adv_ind_int), adv_ind_start_p)
689+
690+
return _mock_take_multi_index(res, adv_ind, adv_ind_start_p)
691+
662692

663693
def to_device(self, target):
664694
"""
@@ -959,39 +989,87 @@ cdef class usm_ndarray:
959989
return _dispatch_binary_elementwise2(first, "right_shift", other)
960990
return NotImplemented
961991

962-
def __setitem__(self, key, rhs):
    """Assign ``rhs`` to the elements of this array selected by ``key``.

    ``key`` is decomposed by ``_basic_slice_meta`` into a basic-slicing
    view (shape, strides, offset) plus advanced-index arrays and the
    position at which they start. With no advanced indices, ``rhs`` is
    copied into the view directly. Otherwise the assignment is delegated
    to the host-side ``_mock_place`` / ``_mock_put_multi_index`` helpers.

    Raises:
        ValueError: if the array is read-only, or if ``rhs`` could not be
            converted during basic-slicing assignment.
        RuntimeError: if ``_basic_slice_meta`` returns fewer than five
            items.
    """
    cdef tuple _meta
    cdef usm_ndarray Xv

    if (self.flags_ & USM_ARRAY_WRITABLE) == 0:
        raise ValueError("Can not modify read-only array.")

    _meta = _basic_slice_meta(
        key, (<object>self).shape, (<object> self).strides,
        self.get_offset()
    )

    if len(_meta) < 5:
        raise RuntimeError(
            "_basic_slice_meta returned fewer than 5 items; "
            "advanced indexing metadata is missing"
        )

    # View over the same buffer described by the basic-slicing part of key.
    Xv = usm_ndarray.__new__(
        usm_ndarray,
        _meta[0],
        dtype=_make_typestr(self.typenum_),
        strides=_meta[1],
        buffer=self.base_,
        offset=_meta[2],
    )
    # set flags and namespace
    Xv.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
    Xv.array_namespace_ = self.array_namespace_

    from ._copy_utils import (
        _copy_from_numpy_into,
        _copy_from_usm_ndarray_to_usm_ndarray,
        _mock_nonzero,
        _mock_place,
        _mock_put_multi_index,
    )

    adv_ind = _meta[3]
    adv_ind_start_p = _meta[4]

    if adv_ind_start_p < 0:
        # basic slicing
        if isinstance(rhs, usm_ndarray):
            _copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
        elif hasattr(rhs, "__sycl_usm_array_interface__"):
            from dpctl.tensor import asarray
            try:
                rhs_ar = asarray(rhs)
                _copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs_ar)
            except Exception as exc:
                raise ValueError(
                    f"Input of type {type(rhs)} could not be "
                    "converted to usm_ndarray"
                ) from exc
        else:
            try:
                rhs_np = np.asarray(rhs)
                _copy_from_numpy_into(Xv, rhs_np)
            except Exception as exc:
                raise ValueError(
                    f"Input of type {type(rhs)} could not be "
                    "converted to usm_ndarray"
                ) from exc
        return

    # A single boolean mask is handled by masked placement.
    if len(adv_ind) == 1 and adv_ind[0].dtype == dpt_bool:
        _mock_place(Xv, adv_ind[0], adv_ind_start_p, rhs)
        return

    # Boolean masks mixed with other indices are first expanded into
    # integer index arrays via nonzero.
    if any(ind.dtype == dpt_bool for ind in adv_ind):
        adv_ind_int = list()
        for ind in adv_ind:
            if ind.dtype == dpt_bool:
                adv_ind_int.extend(_mock_nonzero(ind))
            else:
                adv_ind_int.append(ind)
        adv_ind = tuple(adv_ind_int)

    _mock_put_multi_index(Xv, adv_ind, adv_ind_start_p, rhs)
    return
1072+
9951073

9961074
def __sub__(first, other):
9971075
"See comment in __add__"

0 commit comments

Comments
 (0)