Skip to content

Commit e052ee2

Browse files
authored
Merge pull request #2193 from IntelPython/update-dlpack-v1.2
Update dlpack to v1.2
2 parents 6859311 + 043481b commit e052ee2

File tree

4 files changed

+389
-84
lines changed

4 files changed

+389
-84
lines changed

dpctl/tensor/_dlpack.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
4141
int device_WebGPU "kDLWebGPU"
4242
int device_Hexagon "kDLHexagon"
4343
int device_MAIA "kDLMAIA"
44+
int device_Trn "kDLTrn"
4445

4546
cpdef object to_dlpack_capsule(usm_ndarray array) except +
4647
cpdef object to_dlpack_versioned_capsule(

dpctl/tensor/_dlpack.pyx

Lines changed: 68 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ from .._backend cimport (
3636
DPCTLSyclDeviceRef,
3737
DPCTLSyclUSMRef,
3838
)
39-
from ._usmarray cimport USM_ARRAY_C_CONTIGUOUS, USM_ARRAY_WRITABLE, usm_ndarray
39+
from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
4040

4141
import ctypes
4242

@@ -76,6 +76,7 @@ cdef extern from "dlpack/dlpack.h" nogil:
7676
kDLWebGPU
7777
kDLHexagon
7878
kDLMAIA
79+
kDLTrn
7980

8081
ctypedef struct DLDevice:
8182
DLDeviceType device_type
@@ -88,6 +89,17 @@ cdef extern from "dlpack/dlpack.h" nogil:
8889
kDLBfloat
8990
kDLComplex
9091
kDLBool
92+
kDLFloat8_e3m4
93+
kDLFloat8_e4m3
94+
kDLFloat8_e4m3b11fnuz
95+
kDLFloat8_e4m3fn
96+
kDLFloat8_e4m3fnuz
97+
kDLFloat8_e5m2
98+
kDLFloat8_e5m2fnuz
99+
kDLFloat8_e8m0fnu
100+
kDLFloat6_e2m3fn
101+
kDLFloat6_e3m2fn
102+
kDLFloat4_e2m1fn
91103

92104
ctypedef struct DLDataType:
93105
uint8_t code
@@ -254,7 +266,6 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
254266
cdef int64_t *shape_strides_ptr = NULL
255267
cdef int i = 0
256268
cdef int device_id = -1
257-
cdef int flags = 0
258269
cdef Py_ssize_t element_offset = 0
259270
cdef Py_ssize_t byte_offset = 0
260271
cdef Py_ssize_t si = 1
@@ -269,22 +280,21 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
269280
raise MemoryError(
270281
"to_dlpack_capsule: Could not allocate memory for DLManagedTensor"
271282
)
272-
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
273-
if shape_strides_ptr is NULL:
274-
stdlib.free(dlm_tensor)
275-
raise MemoryError(
276-
"to_dlpack_capsule: Could not allocate memory for shape/strides"
277-
)
278-
shape_ptr = usm_ary.get_shape()
279-
for i in range(nd):
280-
shape_strides_ptr[i] = shape_ptr[i]
281-
strides_ptr = usm_ary.get_strides()
282-
flags = usm_ary.flags_
283-
if strides_ptr:
283+
if nd > 0:
284+
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
285+
if shape_strides_ptr is NULL:
286+
stdlib.free(dlm_tensor)
287+
raise MemoryError(
288+
"to_dlpack_capsule: Could not allocate memory for shape/strides"
289+
)
290+
shape_ptr = usm_ary.get_shape()
284291
for i in range(nd):
285-
shape_strides_ptr[nd + i] = strides_ptr[i]
286-
else:
287-
if not (flags & USM_ARRAY_C_CONTIGUOUS):
292+
shape_strides_ptr[i] = shape_ptr[i]
293+
strides_ptr = usm_ary.get_strides()
294+
if strides_ptr:
295+
for i in range(nd):
296+
shape_strides_ptr[nd + i] = strides_ptr[i]
297+
else:
288298
si = 1
289299
for i in range(0, nd):
290300
shape_strides_ptr[nd + i] = si
@@ -300,11 +310,8 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
300310
dl_tensor.data = <void*>(data_ptr - byte_offset)
301311
dl_tensor.ndim = nd
302312
dl_tensor.byte_offset = <uint64_t>byte_offset
303-
dl_tensor.shape = &shape_strides_ptr[0]
304-
if strides_ptr is NULL:
305-
dl_tensor.strides = NULL
306-
else:
307-
dl_tensor.strides = &shape_strides_ptr[nd]
313+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
314+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
308315
dl_tensor.device.device_type = kDLOneAPI
309316
dl_tensor.device.device_id = device_id
310317
dl_tensor.dtype.lanes = <uint16_t>1
@@ -384,24 +391,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
384391
"to_dlpack_versioned_capsule: Could not allocate memory "
385392
"for DLManagedTensorVersioned"
386393
)
387-
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
388-
if shape_strides_ptr is NULL:
389-
stdlib.free(dlmv_tensor)
390-
raise MemoryError(
391-
"to_dlpack_versioned_capsule: Could not allocate memory "
392-
"for shape/strides"
393-
)
394-
# this can be a separate function for handling shapes and strides
395-
shape_ptr = usm_ary.get_shape()
396-
for i in range(nd):
397-
shape_strides_ptr[i] = shape_ptr[i]
398-
strides_ptr = usm_ary.get_strides()
399-
flags = usm_ary.flags_
400-
if strides_ptr:
394+
if nd > 0:
395+
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
396+
if shape_strides_ptr is NULL:
397+
stdlib.free(dlmv_tensor)
398+
raise MemoryError(
399+
"to_dlpack_versioned_capsule: Could not allocate memory "
400+
"for shape/strides"
401+
)
402+
# this can be a separate function for handling shapes and strides
403+
shape_ptr = usm_ary.get_shape()
401404
for i in range(nd):
402-
shape_strides_ptr[nd + i] = strides_ptr[i]
403-
else:
404-
if not (flags & USM_ARRAY_C_CONTIGUOUS):
405+
shape_strides_ptr[i] = shape_ptr[i]
406+
strides_ptr = usm_ary.get_strides()
407+
flags = usm_ary.flags_
408+
if strides_ptr:
409+
for i in range(nd):
410+
shape_strides_ptr[nd + i] = strides_ptr[i]
411+
else:
405412
si = 1
406413
for i in range(0, nd):
407414
shape_strides_ptr[nd + i] = si
@@ -419,11 +426,8 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
419426
dl_tensor.data = <void*>(data_ptr - byte_offset)
420427
dl_tensor.ndim = nd
421428
dl_tensor.byte_offset = <uint64_t>byte_offset
422-
dl_tensor.shape = &shape_strides_ptr[0]
423-
if strides_ptr is NULL:
424-
dl_tensor.strides = NULL
425-
else:
426-
dl_tensor.strides = &shape_strides_ptr[nd]
429+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
430+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
427431
dl_tensor.device.device_type = kDLOneAPI
428432
dl_tensor.device.device_id = device_id
429433
dl_tensor.dtype.lanes = <uint16_t>1
@@ -503,10 +507,9 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
503507
"for DLManagedTensorVersioned"
504508
)
505509

506-
is_c_contiguous = npy_ary.flags["C"]
507510
shape = npy_ary.ctypes.shape_as(ctypes.c_int64)
508511
strides = npy_ary.ctypes.strides_as(ctypes.c_int64)
509-
if not is_c_contiguous:
512+
if nd > 0:
510513
if npy_ary.size != 1:
511514
for i in range(nd):
512515
if shape[i] != 1 and strides[i] % itemsize != 0:
@@ -517,18 +520,14 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
517520
"itemsize"
518521
)
519522
shape_strides_ptr = <int64_t *>stdlib.malloc((sizeof(int64_t) * 2) * nd)
520-
else:
521-
# no need to pass strides in this case
522-
shape_strides_ptr = <int64_t *>stdlib.malloc(sizeof(int64_t) * nd)
523-
if shape_strides_ptr is NULL:
524-
stdlib.free(dlmv_tensor)
525-
raise MemoryError(
526-
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
527-
"for shape/strides"
528-
)
529-
for i in range(nd):
530-
shape_strides_ptr[i] = shape[i]
531-
if not is_c_contiguous:
523+
if shape_strides_ptr is NULL:
524+
stdlib.free(dlmv_tensor)
525+
raise MemoryError(
526+
"numpy_to_dlpack_versioned_capsule: Could not allocate memory "
527+
"for shape/strides"
528+
)
529+
for i in range(nd):
530+
shape_strides_ptr[i] = shape[i]
532531
shape_strides_ptr[nd + i] = strides[i] // itemsize
533532

534533
writable_flag = npy_ary.flags["W"]
@@ -540,11 +539,8 @@ cpdef numpy_to_dlpack_versioned_capsule(ndarray npy_ary, bint copied):
540539
dl_tensor.data = <void *> npy_ary.data
541540
dl_tensor.ndim = nd
542541
dl_tensor.byte_offset = <uint64_t>byte_offset
543-
dl_tensor.shape = &shape_strides_ptr[0]
544-
if is_c_contiguous:
545-
dl_tensor.strides = NULL
546-
else:
547-
dl_tensor.strides = &shape_strides_ptr[nd]
542+
dl_tensor.shape = &shape_strides_ptr[0] if nd > 0 else NULL
543+
dl_tensor.strides = &shape_strides_ptr[nd] if nd > 0 else NULL
548544
dl_tensor.device.device_type = kDLCPU
549545
dl_tensor.device.device_id = 0
550546
dl_tensor.dtype.lanes = <uint16_t>1
@@ -816,12 +812,8 @@ cpdef object from_dlpack_capsule(object py_caps):
816812
raise BufferError(
817813
"Can not import DLPack tensor with lanes != 1"
818814
)
819-
offset_min = 0
820-
if dl_tensor.strides is NULL:
821-
for i in range(dl_tensor.ndim):
822-
sz = sz * dl_tensor.shape[i]
823-
offset_max = sz - 1
824-
else:
815+
if dl_tensor.ndim > 0:
816+
offset_min = 0
825817
offset_max = 0
826818
for i in range(dl_tensor.ndim):
827819
stride_i = dl_tensor.strides[i]
@@ -876,15 +868,17 @@ cpdef object from_dlpack_capsule(object py_caps):
876868
(<c_dpctl.SyclQueue>q).get_queue_ref(),
877869
memory_owner=tmp
878870
)
871+
879872
py_shape = list()
880-
for i in range(dl_tensor.ndim):
881-
py_shape.append(dl_tensor.shape[i])
882-
if (dl_tensor.strides is NULL):
883-
py_strides = None
884-
else:
873+
if (dl_tensor.shape is not NULL):
874+
for i in range(dl_tensor.ndim):
875+
py_shape.append(dl_tensor.shape[i])
876+
if (dl_tensor.strides is not NULL):
885877
py_strides = list()
886878
for i in range(dl_tensor.ndim):
887879
py_strides.append(dl_tensor.strides[i])
880+
else:
881+
py_strides = None
888882
if (dl_tensor.dtype.code == kDLUInt):
889883
ary_dt = np.dtype("u" + str(element_bytesize))
890884
elif (dl_tensor.dtype.code == kDLInt):

dpctl/tensor/_usmarray.pyx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class DLDeviceType(IntEnum):
8686
Qualcomm Hexagon DSP
8787
``kDLMAIA``:
8888
Microsoft MAIA device
89+
``kDLTrn``:
90+
AWS Trainium device
8991
"""
9092
kDLCPU = c_dlpack.device_CPU
9193
kDLCUDA = c_dlpack.device_CUDA
@@ -101,6 +103,7 @@ class DLDeviceType(IntEnum):
101103
kDLWebGPU = c_dlpack.device_WebGPU
102104
kDLHexagon = c_dlpack.device_Hexagon
103105
kDLMAIA = c_dlpack.device_MAIA
106+
kDLTrn = c_dlpack.device_Trn
104107

105108

106109
cdef class InternalUSMArrayError(Exception):

0 commit comments

Comments
 (0)