@@ -121,6 +121,39 @@ cdef void _managed_tensor_deleter(DLManagedTensor *dlm_tensor) noexcept with gil
121
121
dlm_tensor.manager_ctx = NULL
122
122
stdlib.free(dlm_tensor)
123
123
124
+ cdef object _get_default_context(c_dpctl.SyclDevice dev) except * :
125
+ try :
126
+ if _IS_LINUX:
127
+ default_context = dev.sycl_platform.default_context
128
+ else :
129
+ default_context = None
130
+ except RuntimeError :
131
+ # RT does not support default_context, e.g. Windows
132
+ default_context = None
133
+
134
+ return default_context
135
+
136
+
137
+ cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except * :
138
+ cdef DPCTLSyclDeviceRef pDRef = NULL
139
+ cdef DPCTLSyclDeviceRef tDRef = NULL
140
+ cdef c_dpctl.SyclDevice p_dev
141
+
142
+ pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
143
+ if pDRef is not NULL :
144
+ # if dev is a sub-device, find its parent
145
+ # and return its overall ordinal id
146
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
147
+ while tDRef is not NULL :
148
+ DPCTLDevice_Delete(pDRef)
149
+ pDRef = tDRef
150
+ tDRef = DPCTLDevice_GetParentDevice(pDRef)
151
+ p_dev = c_dpctl.SyclDevice._create(pDRef)
152
+ return p_dev.get_overall_ordinal()
153
+
154
+ # return overall ordinal id of argument device
155
+ return dev.get_overall_ordinal()
156
+
124
157
125
158
cpdef to_dlpack_capsule(usm_ndarray usm_ary):
126
159
"""
@@ -168,14 +201,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
168
201
ary_sycl_queue = usm_ary.get_sycl_queue()
169
202
ary_sycl_device = ary_sycl_queue.get_sycl_device()
170
203
171
- try :
172
- if _IS_LINUX:
173
- default_context = ary_sycl_device.sycl_platform.default_context
174
- else :
175
- default_context = None
176
- except RuntimeError :
177
- # RT does not support default_context, e.g. Windows
178
- default_context = None
204
+ default_context = _get_default_context(ary_sycl_device)
179
205
if default_context is None :
180
206
# check that ary_sycl_device is a non-partitioned device
181
207
pDRef = DPCTLDevice_GetParentDevice(ary_sycl_device.get_device_ref())
0 commit comments