Skip to content

Commit e3d0838

Browse files
__dlpack_device__() returns ID of the parent device
It gives ID (position in `sycl::device::get_devices()`) for the root device that the array allocation device descends from for sub-devices, or ID of the allocation device is that device is a root device.
1 parent 6efb2c9 commit e3d0838

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

dpctl/tensor/_dlpack.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# cython: language_level=3
1919
# cython: linetrace=True
2020

21+
from .._sycl_device cimport SyclDevice
2122
from ._usmarray cimport usm_ndarray
2223

2324

@@ -32,6 +33,8 @@ cpdef usm_ndarray from_dlpack_capsule(object dltensor) except +
3233

3334
cpdef from_dlpack(array)
3435

36+
cdef int get_parent_device_ordinal_id(SyclDevice dev) except *
37+
3538
cdef class DLPackCreationError(Exception):
3639
"""
3740
A DLPackCreateError exception is raised when constructing

dpctl/tensor/_dlpack.pyx

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,39 @@ cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil
121121
dlm_tensor.manager_ctx = NULL
122122
stdlib.free(dlm_tensor)
123123

124+
cdef object _get_default_context(c_dpctl.SyclDevice dev) except *:
125+
try:
126+
if _IS_LINUX:
127+
default_context = dev.sycl_platform.default_context
128+
else:
129+
default_context = None
130+
except RuntimeError:
131+
# RT does not support default_context, e.g. Windows
132+
default_context = None
133+
134+
return default_context
135+
136+
137+
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except *:
138+
cdef DPCTLSyclDeviceRef pDRef = NULL
139+
cdef DPCTLSyclDeviceRef tDRef = NULL
140+
cdef c_dpctl.SyclDevice p_dev
141+
142+
pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
143+
if pDRef is not NULL:
144+
# if dev is a sub-device, find its parent
145+
# and return its overall ordinal id
146+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
147+
while tDRef is not NULL:
148+
DPCTLDevice_Delete(pDRef)
149+
pDRef = tDRef
150+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
151+
p_dev = c_dpctl.SyclDevice._create(pDRef)
152+
return p_dev.get_overall_ordinal()
153+
154+
# return overall ordinal id of argument device
155+
return dev.get_overall_ordinal()
156+
124157

125158
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
126159
"""
@@ -168,14 +201,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
168201
ary_sycl_queue = usm_ary.get_sycl_queue()
169202
ary_sycl_device = ary_sycl_queue.get_sycl_device()
170203

171-
try:
172-
if _IS_LINUX:
173-
default_context = ary_sycl_device.sycl_platform.default_context
174-
else:
175-
default_context = None
176-
except RuntimeError:
177-
# RT does not support default_context, e.g. Windows
178-
default_context = None
204+
default_context = _get_default_context(ary_sycl_device)
179205
if default_context is None:
180206
# check that ary_sycl_device is a non-partitioned device
181207
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())

dpctl/tensor/_usmarray.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -954,10 +954,10 @@ cdef class usm_ndarray:
954954
DLPackCreationError: when array is allocation on a partitioned
955955
SYCL device
956956
"""
957-
cdef int dev_id = (<c_dpctl.SyclDevice>self.sycl_device).get_overall_ordinal()
957+
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
958958
if dev_id < 0:
959959
raise c_dlpack.DLPackCreationError(
960-
"DLPack protocol is only supported for non-partitioned devices"
960+
"Could not determine id of the device where array was allocated."
961961
)
962962
else:
963963
return (

0 commit comments

Comments
 (0)