@@ -25,6 +25,7 @@ import numpy as np
25
25
import dpctl
26
26
import dpctl.memory as dpmem
27
27
28
+ from ._data_types import bool as dpt_bool
28
29
from ._device import Device
29
30
from ._print import usm_ndarray_repr, usm_ndarray_str
30
31
@@ -34,6 +35,7 @@ from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
34
35
cimport dpctl as c_dpctl
35
36
cimport dpctl.memory as c_dpmem
36
37
cimport dpctl.tensor._dlpack as c_dlpack
38
+
37
39
import dpctl.tensor._flags as _flags
38
40
39
41
include " _stride_utils.pxi"
@@ -648,6 +650,9 @@ cdef class usm_ndarray:
648
650
self .get_offset())
649
651
cdef usm_ndarray res
650
652
653
+ if len (_meta) < 5 :
654
+ raise RuntimeError
655
+
651
656
res = usm_ndarray.__new__ (
652
657
usm_ndarray,
653
658
_meta[0 ],
@@ -658,7 +663,32 @@ cdef class usm_ndarray:
658
663
)
659
664
res.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
660
665
res.array_namespace_ = self .array_namespace_
661
- return res
666
+
667
+ adv_ind = _meta[3 ]
668
+ adv_ind_start_p = _meta[4 ]
669
+
670
+ if adv_ind_start_p < 0 :
671
+ return res
672
+
673
+ from ._copy_utils import (
674
+ _mock_extract,
675
+ _mock_nonzero,
676
+ _mock_take_multi_index,
677
+ )
678
+ if len (adv_ind) == 1 and adv_ind[0 ].dtype == dpt_bool:
679
+ return _mock_extract(res, adv_ind[0 ], adv_ind_start_p)
680
+
681
+ if any (ind.dtype == dpt_bool for ind in adv_ind):
682
+ adv_ind_int = list ()
683
+ for ind in adv_ind:
684
+ if ind.dtype == dpt_bool:
685
+ adv_ind_int.extend(_mock_nonzero(ind))
686
+ else :
687
+ adv_ind_int.append(ind)
688
+ return _mock_take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
689
+
690
+ return _mock_take_multi_index(res, adv_ind, adv_ind_start_p)
691
+
662
692
663
693
def to_device (self , target ):
664
694
"""
@@ -959,39 +989,87 @@ cdef class usm_ndarray:
959
989
return _dispatch_binary_elementwise2(first, " right_shift" , other)
960
990
return NotImplemented
961
991
962
- def __setitem__ (self , key , val ):
963
- try :
964
- Xv = self .__getitem__ (key)
965
- except (ValueError , IndexError ) as e:
966
- raise e
992
+ def __setitem__ (self , key , rhs ):
993
+ cdef tuple _meta
994
+ cdef usm_ndarray Xv
995
+
996
+ if (self .flags_ & USM_ARRAY_WRITABLE) == 0 :
997
+ raise ValueError (" Can not modify read-only array." )
998
+
999
+ _meta = _basic_slice_meta(
1000
+ key, (< object > self ).shape, (< object > self ).strides,
1001
+ self .get_offset()
1002
+ )
1003
+
1004
+ if len (_meta) < 5 :
1005
+ raise RuntimeError
1006
+
1007
+ Xv = usm_ndarray.__new__ (
1008
+ usm_ndarray,
1009
+ _meta[0 ],
1010
+ dtype = _make_typestr(self .typenum_),
1011
+ strides = _meta[1 ],
1012
+ buffer = self .base_,
1013
+ offset = _meta[2 ],
1014
+ )
1015
+ # set flags and namespace
1016
+ Xv.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
1017
+ Xv.array_namespace_ = self .array_namespace_
1018
+
967
1019
from ._copy_utils import (
968
1020
_copy_from_numpy_into,
969
1021
_copy_from_usm_ndarray_to_usm_ndarray,
1022
+ _mock_nonzero,
1023
+ _mock_place,
1024
+ _mock_put_multi_index,
970
1025
)
971
- if ((< usm_ndarray> Xv).flags_ & USM_ARRAY_WRITABLE) == 0 :
972
- raise ValueError (" Can not modify read-only array." )
973
- if isinstance (val, usm_ndarray):
974
- _copy_from_usm_ndarray_to_usm_ndarray(Xv, val)
975
- else :
976
- if hasattr (val, " __sycl_usm_array_interface__" ):
977
- from dpctl.tensor import asarray
978
- try :
979
- val_ar = asarray(val)
980
- _copy_from_usm_ndarray_to_usm_ndarray(Xv, val_ar)
981
- except Exception :
982
- raise ValueError (
983
- f" Input of type {type(val)} could not be "
984
- " converted to usm_ndarray"
985
- )
1026
+
1027
+ adv_ind = _meta[3 ]
1028
+ adv_ind_start_p = _meta[4 ]
1029
+
1030
+ if adv_ind_start_p < 0 :
1031
+ # basic slicing
1032
+ if isinstance (rhs, usm_ndarray):
1033
+ _copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
986
1034
else :
987
- try :
988
- val_np = np.asarray(val)
989
- _copy_from_numpy_into(Xv, val_np)
990
- except Exception :
991
- raise ValueError (
992
- f" Input of type {type(val)} could not be "
993
- " converted to usm_ndarray"
994
- )
1035
+ if hasattr (rhs, " __sycl_usm_array_interface__" ):
1036
+ from dpctl.tensor import asarray
1037
+ try :
1038
+ rhs_ar = asarray(rhs)
1039
+ _copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs_ar)
1040
+ except Exception :
1041
+ raise ValueError (
1042
+ f" Input of type {type(rhs)} could not be "
1043
+ " converted to usm_ndarray"
1044
+ )
1045
+ else :
1046
+ try :
1047
+ rhs_np = np.asarray(rhs)
1048
+ _copy_from_numpy_into(Xv, rhs_np)
1049
+ except Exception :
1050
+ raise ValueError (
1051
+ f" Input of type {type(rhs)} could not be "
1052
+ " converted to usm_ndarray"
1053
+ )
1054
+ return
1055
+
1056
+ if len (adv_ind) == 1 and adv_ind[0 ].dtype == dpt_bool:
1057
+ _mock_place(Xv, adv_ind[0 ], adv_ind_start_p, rhs)
1058
+ return
1059
+
1060
+ if any (ind.dtype == dpt_bool for ind in adv_ind):
1061
+ adv_ind_int = list ()
1062
+ for ind in adv_ind:
1063
+ if ind.dtype == dpt_bool:
1064
+ adv_ind_int.extend(_mock_nonzero(ind))
1065
+ else :
1066
+ adv_ind_int.append(ind)
1067
+ _mock_put_multi_index(Xv, tuple (adv_ind_int), adv_ind_start_p, rhs)
1068
+ return
1069
+
1070
+ _mock_put_multi_index(Xv, adv_ind, adv_ind_start_p, rhs)
1071
+ return
1072
+
995
1073
996
1074
def __sub__ (first , other ):
997
1075
" See comment in __add__"
0 commit comments