|
23 | 23 |
|
24 | 24 | import logging
|
25 | 25 |
|
26 |
| -from ._backend cimport DPCTLEvent_Delete, DPCTLEvent_Wait, DPCTLSyclEventRef |
| 26 | +from cpython cimport pycapsule |
| 27 | + |
| 28 | +from ._backend cimport ( # noqa: E211 |
| 29 | + DPCTLEvent_Copy, |
| 30 | + DPCTLEvent_Create, |
| 31 | + DPCTLEvent_Delete, |
| 32 | + DPCTLEvent_Wait, |
| 33 | + DPCTLSyclEventRef, |
| 34 | +) |
27 | 35 |
|
28 | 36 | __all__ = [
|
29 | 37 | "SyclEvent",
|
| 38 | + "SyclEventRaw", |
30 | 39 | ]
|
31 | 40 |
|
32 | 41 | _logger = logging.getLogger(__name__)
|
@@ -71,3 +80,132 @@ cdef class SyclEvent:
|
71 | 80 | SyclEvent cast to a size_t.
|
72 | 81 | """
|
73 | 82 | return int(<size_t>self._event_ref)
|
| 83 | + |
| 84 | +cdef void _event_capsule_deleter(object o): |
| 85 | + cdef DPCTLSyclEventRef ERef = NULL |
| 86 | + if pycapsule.PyCapsule_IsValid(o, "SyclEventRef"): |
| 87 | + ERef = <DPCTLSyclEventRef> pycapsule.PyCapsule_GetPointer( |
| 88 | + o, "SyclEventRef" |
| 89 | + ) |
| 90 | + DPCTLEvent_Delete(ERef) |
| 91 | + |
| 92 | + |
| 93 | +cdef class _SyclEventRaw: |
| 94 | + """ Python wrapper class for a ``cl::sycl::event``. |
| 95 | + """ |
| 96 | + |
| 97 | + def __dealloc__(self): |
| 98 | + DPCTLEvent_Delete(self._event_ref) |
| 99 | + |
| 100 | + |
| 101 | +cdef class SyclEventRaw(_SyclEventRaw): |
| 102 | + """ Python wrapper class for a ``cl::sycl::event``. |
| 103 | + """ |
| 104 | + |
| 105 | + @staticmethod |
| 106 | + cdef void _init_helper(_SyclEventRaw event, DPCTLSyclEventRef ERef): |
| 107 | + event._event_ref = ERef |
| 108 | + |
| 109 | + @staticmethod |
| 110 | + cdef SyclEventRaw _create(DPCTLSyclEventRef eref): |
| 111 | + cdef _SyclEventRaw ret = _SyclEventRaw.__new__(_SyclEventRaw) |
| 112 | + SyclEventRaw._init_helper(ret, eref) |
| 113 | + return SyclEventRaw(ret) |
| 114 | + |
| 115 | + cdef int _init_event_default(self): |
| 116 | + self._event_ref = DPCTLEvent_Create() |
| 117 | + if (self._event_ref is NULL): |
| 118 | + return -1 |
| 119 | + return 0 |
| 120 | + |
| 121 | + cdef int _init_event_from__SyclEventRaw(self, _SyclEventRaw other): |
| 122 | + self._event_ref = DPCTLEvent_Copy(other._event_ref) |
| 123 | + if (self._event_ref is NULL): |
| 124 | + return -1 |
| 125 | + return 0 |
| 126 | + |
| 127 | + cdef int _init_event_from_SyclEvent(self, SyclEvent event): |
| 128 | + self._event_ref = DPCTLEvent_Copy(event._event_ref) |
| 129 | + if (self._event_ref is NULL): |
| 130 | + return -1 |
| 131 | + return 0 |
| 132 | + |
| 133 | + cdef int _init_event_from_capsule(self, object cap): |
| 134 | + cdef DPCTLSyclEventRef ERef = NULL |
| 135 | + cdef DPCTLSyclEventRef ERef_copy = NULL |
| 136 | + cdef int ret = 0 |
| 137 | + if pycapsule.PyCapsule_IsValid(cap, "SyclEventRef"): |
| 138 | + ERef = <DPCTLSyclEventRef> pycapsule.PyCapsule_GetPointer( |
| 139 | + cap, "SyclEventRef" |
| 140 | + ) |
| 141 | + if (ERef is NULL): |
| 142 | + return -2 |
| 143 | + ret = pycapsule.PyCapsule_SetName(cap, "used_SyclEventRef") |
| 144 | + if (ret): |
| 145 | + return -2 |
| 146 | + ERef_copy = DPCTLEvent_Copy(ERef) |
| 147 | + if (ERef_copy is NULL): |
| 148 | + return -3 |
| 149 | + self._event_ref = ERef_copy |
| 150 | + return 0 |
| 151 | + else: |
| 152 | + return -128 |
| 153 | + |
| 154 | + def __cinit__(self, arg=None): |
| 155 | + cdef int ret = 0 |
| 156 | + if arg is None: |
| 157 | + ret = self._init_event_default() |
| 158 | + elif type(arg) is _SyclEventRaw: |
| 159 | + ret = self._init_event_from__SyclEventRaw(<_SyclEventRaw> arg) |
| 160 | + elif isinstance(arg, SyclEvent): |
| 161 | + ret = self._init_event_from_SyclEvent(<SyclEvent> arg) |
| 162 | + elif pycapsule.PyCapsule_IsValid(arg, "SyclEventRef"): |
| 163 | + ret = self._init_event_from_capsule(arg) |
| 164 | + else: |
| 165 | + raise TypeError( |
| 166 | + "Invalid argument." |
| 167 | + ) |
| 168 | + if (ret < 0): |
| 169 | + if (ret == -1): |
| 170 | + raise ValueError("Event failed to be created.") |
| 171 | + elif (ret == -2): |
| 172 | + raise TypeError( |
| 173 | + "Input capsule {} contains a null pointer or could not be" |
| 174 | + " renamed".format(arg) |
| 175 | + ) |
| 176 | + elif (ret == -3): |
| 177 | + raise ValueError( |
| 178 | + "Internal Error: Could not create a copy of a sycl event." |
| 179 | + ) |
| 180 | + raise ValueError( |
| 181 | + "Unrecognized error code ({}) encountered.".format(ret) |
| 182 | + ) |
| 183 | + |
| 184 | + cdef DPCTLSyclEventRef get_event_ref(self): |
| 185 | + """ Returns the `DPCTLSyclEventRef` pointer for this class. |
| 186 | + """ |
| 187 | + return self._event_ref |
| 188 | + |
| 189 | + cpdef void wait(self): |
| 190 | + DPCTLEvent_Wait(self._event_ref) |
| 191 | + |
| 192 | + def addressof_ref(self): |
| 193 | + """ Returns the address of the C API `DPCTLSyclEventRef` pointer as |
| 194 | + a size_t. |
| 195 | +
|
| 196 | + Returns: |
| 197 | + The address of the `DPCTLSyclEventRef` object used to create this |
| 198 | + `SyclEvent` cast to a size_t. |
| 199 | + """ |
| 200 | + return <size_t>self._event_ref |
| 201 | + |
| 202 | + def _get_capsule(self): |
| 203 | + cdef DPCTLSyclEventRef ERef = NULL |
| 204 | + ERef = DPCTLEvent_Copy(self._event_ref) |
| 205 | + if (ERef is NULL): |
| 206 | + raise ValueError("SyclEvent copy failed.") |
| 207 | + return pycapsule.PyCapsule_New( |
| 208 | + <void *>ERef, |
| 209 | + "SyclEventRef", |
| 210 | + &_event_capsule_deleter |
| 211 | + ) |
0 commit comments