Skip to content

Commit e9a9633

Browse files
Exposed SyclQueue staticmethod to construct from context & dev.
Used the constructor in processing of sycl_usm_array_interface. Added test to check that it works.
1 parent b7ab952 commit e9a9633

File tree

5 files changed

+51
-4
lines changed

5 files changed

+51
-4
lines changed

dpctl/_backend.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ cdef extern from "dppl_sycl_queue_manager.h":
210210
DPPLSyclDeviceType DeviceTy,
211211
size_t DNum
212212
)
213+
cdef DPPLSyclQueueRef DPPLQueueMgr_GetQueueFromContextAndDevice(
214+
DPPLSyclContextRef CRef,
215+
DPPLSyclDeviceRef DRef)
213216

214217

215218
cdef extern from "dppl_sycl_usm_interface.h":

dpctl/_memory.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ cdef class _BufferData:
9494
cdef _BufferData buf
9595
cdef Py_ssize_t arr_data_ptr
9696
cdef SyclDevice dev
97+
cdef SyclContext ctx
9798

9899
if ary_version != 1:
99100
_throw_sycl_usm_ary_iface()
@@ -128,8 +129,9 @@ cdef class _BufferData:
128129
#
129130
# cdef SyclQueue new_queue = SyclQueue._create_from_dev_context(dev, <SyclContext> ary_syclobj)
130131
# buf.queue = new_queue
131-
dev = Memory.get_pointer_device(buf.p, <SyclContext> ary_syclobj)
132-
buf.queue = get_current_queue()
132+
ctx = <SyclContext> ary_syclobj
133+
dev = Memory.get_pointer_device(buf.p, ctx)
134+
buf.queue = SyclQueue._create_from_context_and_device(ctx, dev)
133135

134136
return buf
135137

dpctl/_sycl_core.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ cdef class SyclQueue:
128128

129129
@staticmethod
130130
cdef SyclQueue _create (DPPLSyclQueueRef qref)
131+
@staticmethod
132+
cdef SyclQueue _create_from_context_and_device (SyclContext ctx, SyclDevice dev)
131133
cpdef bool equals (self, SyclQueue q)
132134
cpdef SyclContext get_sycl_context (self)
133135
cpdef SyclDevice get_sycl_device (self)

dpctl/_sycl_core.pyx

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,21 @@ cdef class SyclQueue:
373373
ret._queue_ref = qref
374374
return ret
375375

376+
@staticmethod
377+
cdef SyclQueue _create_from_context_and_device(SyclContext ctx, SyclDevice dev):
378+
cdef SyclQueue ret = SyclQueue.__new__(SyclQueue)
379+
cdef DPPLSyclContextRef cref = ctx.get_context_ref()
380+
cdef DPPLSyclDeviceRef dref = dev.get_device_ref()
381+
cdef DPPLSyclQueueRef qref = DPPLQueueMgr_GetQueueFromContextAndDevice(
382+
cref, dref)
383+
384+
if qref is NULL:
385+
raise SyclQueueCreationError("Queue creation failed.")
386+
ret._queue_ref = qref
387+
ret._context = ctx
388+
ret._device = dev
389+
return ret
390+
376391
def __dealloc__ (self):
377392
DPPLQueue_Delete(self._queue_ref)
378393

dpctl/tests/test_sycl_usm.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,21 @@
2424

2525
import unittest
2626
import dpctl
27-
from dpctl._memory import MemoryUSMShared, MemoryUSMHost, MemoryUSMDevice
28-
27+
from dpctl import MemoryUSMShared, MemoryUSMHost, MemoryUSMDevice
28+
import dpctl._memory
29+
import numpy as np
30+
31+
class Dummy(MemoryUSMShared):
32+
"""
33+
Class that exposes `__sycl_usm_array_interface__` with
34+
SYCL context for sycl object, instead of Sycl queue.
35+
"""
36+
@property
37+
def __sycl_usm_array_interface(self):
38+
iface = super().__sycl_usm_array_interface__
39+
iface['syclob'] = iface['syclobj'].get_sycl_context()
40+
return iface
41+
2942

3043
class TestMemory(unittest.TestCase):
3144
@unittest.skipUnless(
@@ -187,6 +200,18 @@ def test_create_with_only_size(self):
187200
self.assertEqual(m.nbytes, 1024)
188201
self.assertEqual(m.get_usm_type(), self.usm_type)
189202

203+
@unittest.skipUnless(
204+
dpctl.has_sycl_platforms(), "No SYCL Devices except the default host device."
205+
)
206+
def test_sycl_usm_array_interface(self):
207+
m = self.MemoryUSMClass(256)
208+
m2 = Dummy(m.nbytes)
209+
hb = np.random.randint(0, 256, size=256, dtype="|u1")
210+
m2.copy_from_host(hb)
211+
# test that USM array interface works with SyclContext as 'syclobj'
212+
m.copy_from_device(m2)
213+
self.assertTrue(np.array_equal(m.copy_to_host(), hb))
214+
190215

191216
class TestMemoryUSMShared(TestMemoryUSMBase, unittest.TestCase):
192217
""" Tests for MemoryUSMShared """

0 commit comments

Comments
 (0)