Skip to content

Commit 8b12038

Browse files
Fixes gh-1071
Whenever start of array decoded from its dlpack capsule data did not start from the beginning of the allocation, like with x0 = dpt.arrange(6); x1 = x0[::-2]; dpt.from_dlpack(x1) which failed, vs. x0 = dpt.arrange(7); x1 = x0[::-2]; dpt.from_dlpack(x1) which worked.
1 parent 8e63098 commit 8b12038

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,17 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
311311
cdef DLManagedTensor *dlm_tensor = NULL
312312
cdef bytes usm_type
313313
cdef size_t sz = 1
314+
cdef size_t alloc_sz = 1
314315
cdef int i
315316
cdef int device_id = -1
316317
cdef int element_bytesize = 0
317318
cdef Py_ssize_t offset_min = 0
318319
cdef Py_ssize_t offset_max = 0
319-
cdef int64_t stride_i
320320
cdef char *mem_ptr = NULL
321+
cdef Py_ssize_t mem_ptr_delta = 0
321322
cdef Py_ssize_t element_offset = 0
323+
cdef int64_t stride_i = -1
324+
cdef int64_t shape_i = -1
322325

323326
if not cpython.PyCapsule_IsValid(py_caps, 'dltensor'):
324327
if cpython.PyCapsule_IsValid(py_caps, 'used_dltensor'):
@@ -370,22 +373,22 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
370373
raise BufferError(
371374
"Can not import DLPack tensor with lanes != 1"
372375
)
376+
offset_min = 0
373377
if dlm_tensor.dl_tensor.strides is NULL:
374378
for i in range(dlm_tensor.dl_tensor.ndim):
375379
sz = sz * dlm_tensor.dl_tensor.shape[i]
380+
offset_max = sz - 1
376381
else:
377-
offset_min = 0
378382
offset_max = 0
379383
for i in range(dlm_tensor.dl_tensor.ndim):
380384
stride_i = dlm_tensor.dl_tensor.strides[i]
381-
if stride_i > 0:
382-
offset_max = offset_max + stride_i * (
383-
dlm_tensor.dl_tensor.shape[i] - 1
384-
)
385-
else:
386-
offset_min = offset_min + stride_i * (
387-
dlm_tensor.dl_tensor.shape[i] - 1
388-
)
385+
shape_i = dlm_tensor.dl_tensor.shape[i]
386+
if shape_i > 1:
387+
shape_i -= 1
388+
if stride_i > 0:
389+
offset_max = offset_max + stride_i * shape_i
390+
else:
391+
offset_min = offset_min + stride_i * shape_i
389392
sz = offset_max - offset_min + 1
390393
if sz == 0:
391394
sz = 1
@@ -401,14 +404,29 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
401404
if dlm_tensor.dl_tensor.data is NULL:
402405
usm_mem = dpmem.MemoryUSMDevice(sz, q)
403406
else:
404-
mem_ptr = <char *>dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
405-
mem_ptr = mem_ptr - (element_offset * element_bytesize)
406-
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407+
mem_ptr_delta = dlm_tensor.dl_tensor.byte_offset - (
408+
element_offset * element_bytesize
409+
)
410+
mem_ptr = <char *>dlm_tensor.dl_tensor.data
411+
alloc_sz = dlm_tensor.dl_tensor.byte_offset + <uint64_t>(
412+
(offset_max + 1) * element_bytesize)
413+
tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407414
<DPCTLSyclUSMRef> mem_ptr,
408-
sz,
415+
max(alloc_sz, <uint64_t>element_bytesize),
409416
(<c_dpctl.SyclQueue>q).get_queue_ref(),
410417
memory_owner=dlm_holder
411418
)
419+
if mem_ptr_delta == 0:
420+
usm_mem = tmp
421+
else:
422+
alloc_sz = dlm_tensor.dl_tensor.byte_offset + <uint64_t>(
423+
(offset_max * element_bytesize + mem_ptr_delta))
424+
usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
425+
<DPCTLSyclUSMRef> (mem_ptr + (element_bytesize - mem_ptr_delta)),
426+
max(alloc_sz, <uint64_t>element_bytesize),
427+
(<c_dpctl.SyclQueue>q).get_queue_ref(),
428+
memory_owner=tmp
429+
)
412430
py_shape = list()
413431
for i in range(dlm_tensor.dl_tensor.ndim):
414432
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -427,7 +445,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
427445
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
428446
ary_dt = np.dtype("c" + str(element_bytesize))
429447
else:
430-
raise ValueError(
448+
raise BufferError(
431449
"Can not import DLPack tensor with type code {}.".format(
432450
<object>dlm_tensor.dl_tensor.dtype.code
433451
)
@@ -441,7 +459,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
441459
)
442460
return res_ary
443461
else:
444-
raise ValueError(
462+
raise BufferError(
445463
"The DLPack tensor resides on unsupported device."
446464
)
447465

0 commit comments

Comments
 (0)