Skip to content

Commit 16c8996

Browse files
authored
Update _usmarray.pyx
1 parent 76a65bb commit 16c8996

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,20 +1585,35 @@ cdef class usm_ndarray:
15851585
return usm_ndarray_repr(self)
15861586

15871587
def __array__(self, dtype=None, /, *, copy=None):
1588-
"""NumPy's array protocol method to disallow implicit conversion.
1589-
1590-
Without this definition, `numpy.asarray(usm_ar)` converts
1591-
usm_ndarray instance into NumPy array with data type `object`
1592-
and every element being 0d usm_ndarray.
1593-
1594-
https://github.com/IntelPython/dpctl/pull/1384#issuecomment-1707212972
1595-
"""
1596-
raise TypeError(
1597-
"Implicit conversion to a NumPy array is not allowed. "
1598-
"Use `dpctl.tensor.asnumpy` to copy data from this "
1599-
"`dpctl.tensor.usm_ndarray` instance to NumPy array"
1588+
if copy is False:
1589+
raise TypeError("dpctl.tensors must copy data from device")
1590+
1591+
# it is assumed that copy=None requires a copy due to
1592+
# the change of a dpctl dtype to a NumPy dtype therefore
1593+
# not violating the NumPy standard for the __array__
1594+
# method.
1595+
1596+
if self.size == 0:
1597+
# no data needs to be copied for zero sized array
1598+
return np.ndarray(self.shape, dtype=self.dtype)
1599+
nb = self.usm_data.nbytes
1600+
q = self.sycl_queue
1601+
hh = dpmem.MemoryUSMHost(nb, queue=q)
1602+
h = np.ndarray(nb, dtype="u1", buffer=hh).view(self.dtype)
1603+
itsz = self.itemsize
1604+
strides_bytes = tuple(si * itsz for si in self.strides)
1605+
offset = self._element_offset * itsz
1606+
# ensure that content of ary.usm_data is final
1607+
q.wait()
1608+
hh.copy_from_device(self.usm_data)
1609+
ndarray = np.ndarray(
1610+
self.shape,
1611+
dtype=self.dtype,
1612+
buffer=h,
1613+
strides=strides_bytes,
1614+
offset=offset,
16001615
)
1601-
1616+
return ndarray if dtype is None else ndarray.astype(dtype)
16021617

16031618
cdef usm_ndarray _real_view(usm_ndarray ary):
16041619
"""

0 commit comments

Comments
 (0)