Skip to content

Commit fe5f497

Browse files
Add static method wait for SyclEventRaw class
1 parent 1e45e4c commit fe5f497

File tree

4 files changed

+35
-4
lines changed

4 files changed

+35
-4
lines changed

dpctl/_backend.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ cdef extern from "dpctl_sycl_event_interface.h":
226226
cdef DPCTLSyclEventRef DPCTLEvent_Create()
227227
cdef DPCTLSyclEventRef DPCTLEvent_Copy(const DPCTLSyclEventRef ERef)
228228
cdef void DPCTLEvent_Wait(DPCTLSyclEventRef ERef)
229+
cdef void DPCTLEvent_WaitAndThrow(DPCTLSyclEventRef ERef)
229230
cdef void DPCTLEvent_Delete(DPCTLSyclEventRef ERef)
230231
cdef _event_status_type DPCTLEvent_GetCommandExecutionStatus(DPCTLSyclEventRef ERef)
231232
cdef _backend_type DPCTLEvent_GetBackend(DPCTLSyclEventRef ERef)

dpctl/_sycl_event.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ cdef public class SyclEventRaw(_SyclEventRaw) [object PySyclEventRawObject, type
4646
cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other)
4747
cdef int _init_event_from_SyclEvent(self, SyclEvent event)
4848
cdef int _init_event_from_capsule(self, object caps)
49-
cdef DPCTLSyclEventRef get_event_ref (self)
50-
cpdef void wait (self)
49+
cdef DPCTLSyclEventRef get_event_ref (self)
50+
cdef void _wait (SyclEventRaw event)

dpctl/_sycl_event.pyx

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ from ._backend cimport ( # noqa: E211
3737
DPCTLEvent_GetProfilingInfoSubmit,
3838
DPCTLEvent_GetWaitList,
3939
DPCTLEvent_Wait,
40+
DPCTLEvent_WaitAndThrow,
4041
DPCTLEventVector_Delete,
4142
DPCTLEventVector_GetAt,
4243
DPCTLEventVector_Size,
@@ -199,8 +200,20 @@ cdef class SyclEventRaw(_SyclEventRaw):
199200
"""
200201
return self._event_ref
201202

202-
cpdef void wait(self):
203-
DPCTLEvent_Wait(self._event_ref)
203+
@staticmethod
204+
cdef void _wait(SyclEventRaw event):
205+
DPCTLEvent_WaitAndThrow(event._event_ref)
206+
207+
@staticmethod
208+
def wait(event):
209+
if isinstance(event, list):
210+
for e in event:
211+
SyclEventRaw._wait(e)
212+
elif isinstance(event, SyclEventRaw):
213+
SyclEventRaw._wait(event)
214+
else:
215+
raise ValueError("The passed argument is not a list \
216+
or a SyclEventRaw type.")
204217

205218
def addressof_ref(self):
206219
""" Returns the address of the C API `DPCTLSyclEventRef` pointer as

dpctl/tests/test_sycl_event.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,23 @@ def test_create_event_raw_from_capsule():
8181
pytest.fail("Failed to create an event from capsule")
8282

8383

84+
def test_wait_with_event():
85+
event = dpctl.SyclEventRaw()
86+
try:
87+
dpctl.SyclEventRaw.wait(event)
88+
except ValueError:
89+
pytest.fail("Failed to wait for the event")
90+
91+
92+
def test_wait_with_list():
93+
event_1 = dpctl.SyclEventRaw()
94+
event_2 = dpctl.SyclEventRaw()
95+
try:
96+
dpctl.SyclEventRaw.wait([event_1, event_2])
97+
except ValueError:
98+
pytest.fail("Failed to wait for events from the list")
99+
100+
84101
def test_execution_status():
85102
event = dpctl.SyclEventRaw()
86103
try:

0 commit comments

Comments
 (0)