Skip to content

Commit 7666803

Browse files
Fixes #583
Sycl USM array interface should use the base pointer in the 'data' field, rather than its own 'data'. ``` In [1]: import numpy as np In [2]: import dpctl, dpctl.tensor as dpt, dpctl.memory as dpm In [3]: X = dpt.usm_ndarray(10, dtype="|f8") ...: Y = np.arange(X.size, dtype=X.dtype) ...: X.usm_data.copy_from_host(Y.view("|u1")) In [4]: print("X =", dpm.as_usm_memory(X).copy_to_host().view(X.dtype)) ...: print("X[0] =", dpm.as_usm_memory(X[0]).copy_to_host().view(X.dtype)) ...: print("X[1] =", dpm.as_usm_memory(X[1]).copy_to_host().view(X.dtype)) ...: print("X[2] =", dpm.as_usm_memory(X[2]).copy_to_host().view(X.dtype)) X = [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.] X[0] = [0.] X[1] = [1.] X[2] = [2.] ```
1 parent 5f155bc commit 7666803

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ cdef class usm_ndarray:
308308
mem_ptr = <char *>(<size_t> ary_iface['data'][0])
309309
ary_ptr = <char *>(<size_t> self.data_)
310310
ro_flag = False if (self.flags_ & USM_ARRAY_WRITEABLE) else True
311-
ary_iface['data'] = (<size_t> ary_ptr, ro_flag)
311+
ary_iface['data'] = (<size_t> mem_ptr, ro_flag)
312312
ary_iface['shape'] = self.shape
313313
if (self.strides_):
314314
ary_iface['strides'] = _make_int_tuple(self.nd_, self.strides_)
@@ -335,7 +335,7 @@ cdef class usm_ndarray:
335335
"""
336336
Gives the number of indices needed to address elements of this array.
337337
"""
338-
return int(self.nd_)
338+
return self.nd_
339339

340340
@property
341341
def usm_data(self):

0 commit comments

Comments
 (0)