Skip to content

Commit 98ec732

Browse files
from_dlpack get queue from cache using context & device as key.
1 parent 1ee1e91 commit 98ec732

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

dpctl/tensor/_dlpack.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t
2424

2525
cimport dpctl as c_dpctl
2626
cimport dpctl.memory as c_dpmem
27+
from dpctl._sycl_queue_manager cimport get_device_cached_queue
2728

2829
from .._backend cimport (
2930
DPCTLDevice_Delete,
@@ -344,12 +345,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
344345
if _IS_LINUX:
345346
default_context = root_device.sycl_platform.default_context
346347
else:
347-
default_context = dpctl.SyclQueue(root_device).sycl_context
348+
default_context = get_device_cached_queue(root_device).sycl_context
348349
except RuntimeError:
349-
default_context = dpctl.SyclQueue(root_device).sycl_context
350+
default_context = get_device_cached_queue(root_device).sycl_context
350351
if dlm_tensor.dl_tensor.data is NULL:
351352
usm_type = b"device"
352-
q = dpctl.SyclQueue(default_context, root_device)
353+
q = get_device_cached_queue((default_context, root_device,))
353354
else:
354355
usm_type = c_dpmem._Memory.get_pointer_type(
355356
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
@@ -364,7 +365,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
364365
<DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
365366
<c_dpctl.SyclContext>default_context
366367
)
367-
q = dpctl.SyclQueue(default_context, alloc_device)
368+
q = get_device_cached_queue((default_context, alloc_device,))
368369
if dlm_tensor.dl_tensor.dtype.bits % 8:
369370
raise BufferError(
370371
"Can not import DLPack tensor whose element's "

0 commit comments

Comments
 (0)