@@ -311,14 +311,17 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
311
311
cdef DLManagedTensor * dlm_tensor = NULL
312
312
cdef bytes usm_type
313
313
cdef size_t sz = 1
314
+ cdef size_t alloc_sz = 1
314
315
cdef int i
315
316
cdef int device_id = - 1
316
317
cdef int element_bytesize = 0
317
318
cdef Py_ssize_t offset_min = 0
318
319
cdef Py_ssize_t offset_max = 0
319
- cdef int64_t stride_i
320
320
cdef char * mem_ptr = NULL
321
+ cdef Py_ssize_t mem_ptr_delta = 0
321
322
cdef Py_ssize_t element_offset = 0
323
+ cdef int64_t stride_i = - 1
324
+ cdef int64_t shape_i = - 1
322
325
323
326
if not cpython.PyCapsule_IsValid(py_caps, ' dltensor' ):
324
327
if cpython.PyCapsule_IsValid(py_caps, ' used_dltensor' ):
@@ -370,22 +373,22 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
370
373
raise BufferError(
371
374
" Can not import DLPack tensor with lanes != 1"
372
375
)
376
+ offset_min = 0
373
377
if dlm_tensor.dl_tensor.strides is NULL :
374
378
for i in range (dlm_tensor.dl_tensor.ndim):
375
379
sz = sz * dlm_tensor.dl_tensor.shape[i]
380
+ offset_max = sz - 1
376
381
else :
377
- offset_min = 0
378
382
offset_max = 0
379
383
for i in range (dlm_tensor.dl_tensor.ndim):
380
384
stride_i = dlm_tensor.dl_tensor.strides[i]
381
- if stride_i > 0 :
382
- offset_max = offset_max + stride_i * (
383
- dlm_tensor.dl_tensor.shape[i] - 1
384
- )
385
- else :
386
- offset_min = offset_min + stride_i * (
387
- dlm_tensor.dl_tensor.shape[i] - 1
388
- )
385
+ shape_i = dlm_tensor.dl_tensor.shape[i]
386
+ if shape_i > 1 :
387
+ shape_i -= 1
388
+ if stride_i > 0 :
389
+ offset_max = offset_max + stride_i * shape_i
390
+ else :
391
+ offset_min = offset_min + stride_i * shape_i
389
392
sz = offset_max - offset_min + 1
390
393
if sz == 0 :
391
394
sz = 1
@@ -401,14 +404,29 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
401
404
if dlm_tensor.dl_tensor.data is NULL :
402
405
usm_mem = dpmem.MemoryUSMDevice(sz, q)
403
406
else :
404
- mem_ptr = < char * > dlm_tensor.dl_tensor.data + dlm_tensor.dl_tensor.byte_offset
405
- mem_ptr = mem_ptr - (element_offset * element_bytesize)
406
- usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407
+ mem_ptr_delta = dlm_tensor.dl_tensor.byte_offset - (
408
+ element_offset * element_bytesize
409
+ )
410
+ mem_ptr = < char * > dlm_tensor.dl_tensor.data
411
+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
412
+ (offset_max + 1 ) * element_bytesize)
413
+ tmp = c_dpmem._Memory.create_from_usm_pointer_size_qref(
407
414
< DPCTLSyclUSMRef> mem_ptr,
408
- sz ,
415
+ max (alloc_sz, < uint64_t > element_bytesize) ,
409
416
(< c_dpctl.SyclQueue> q).get_queue_ref(),
410
417
memory_owner = dlm_holder
411
418
)
419
+ if mem_ptr_delta == 0 :
420
+ usm_mem = tmp
421
+ else :
422
+ alloc_sz = dlm_tensor.dl_tensor.byte_offset + < uint64_t> (
423
+ (offset_max * element_bytesize + mem_ptr_delta))
424
+ usm_mem = c_dpmem._Memory.create_from_usm_pointer_size_qref(
425
+ < DPCTLSyclUSMRef> (mem_ptr + (element_bytesize - mem_ptr_delta)),
426
+ max (alloc_sz, < uint64_t> element_bytesize),
427
+ (< c_dpctl.SyclQueue> q).get_queue_ref(),
428
+ memory_owner = tmp
429
+ )
412
430
py_shape = list ()
413
431
for i in range (dlm_tensor.dl_tensor.ndim):
414
432
py_shape.append(dlm_tensor.dl_tensor.shape[i])
@@ -427,7 +445,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
427
445
elif (dlm_tensor.dl_tensor.dtype.code == kDLComplex):
428
446
ary_dt = np.dtype(" c" + str (element_bytesize))
429
447
else :
430
- raise ValueError (
448
+ raise BufferError (
431
449
" Can not import DLPack tensor with type code {}." .format(
432
450
< object > dlm_tensor.dl_tensor.dtype.code
433
451
)
@@ -441,7 +459,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
441
459
)
442
460
return res_ary
443
461
else :
444
- raise ValueError (
462
+ raise BufferError (
445
463
" The DLPack tensor resides on unsupported device."
446
464
)
447
465
0 commit comments