Skip to content

Commit 1ded0a6

Browse files
committed
Add sycl event wait overload
1 parent 9e6c224 commit 1ded0a6

File tree

7 files changed

+64
-6
lines changed

7 files changed

+64
-6
lines changed

numba_dpex/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def parse_sem_version(version_string: str) -> Tuple[int, int, int]:
107107

108108
# Re-export all type names
109109
from numba_dpex.core.types import * # noqa E402
110+
from numba_dpex.dpctl_iface import _intrinsic # noqa E402
110111
from numba_dpex.dpnp_iface import dpnpimpl # noqa E402
111112

112113
if config.HAS_NON_HOST_DEVICE:

numba_dpex/core/datamodel/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def _init_data_model_manager() -> datamodel.DataModelManager:
280280

281281
# Register the DpctlSyclEvent type
282282
register_model(DpctlSyclEvent)(SyclEventModel)
283+
283284
# Register the RangeType type
284285
register_model(RangeType)(RangeModel)
285286

numba_dpex/core/types/dpctl_types.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,7 @@ def box_sycl_queue(typ, val, c):
123123
class DpctlSyclEvent(types.Type):
124124
"""A Numba type to represent a dpctl.SyclEvent PyObject."""
125125

126-
def __init__(self, sycl_event):
127-
if not isinstance(sycl_event, SyclEvent):
128-
raise TypeError("The argument sycl_event is not of type SyclEvent.")
129-
126+
def __init__(self):
130127
super(DpctlSyclEvent, self).__init__(name="DpctlSyclEvent")
131128

132129
@property

numba_dpex/core/typing/typeof.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def typeof_dpctl_sycl_event(val, c):
121121
122122
Returns: A numba_dpex.core.types.dpctl_types.DpctlSyclEvent instance.
123123
"""
124-
return DpctlSyclEvent(val)
124+
return DpctlSyclEvent()
125125

126126

127127
@typeof_impl.register(Range)

numba_dpex/dpctl_iface/_intrinsic.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba import types
6+
from numba.core.datamodel import default_manager
7+
from numba.extending import intrinsic, overload_method
8+
9+
import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
10+
from numba_dpex.core import types as dpex_types
11+
12+
13+
@intrinsic
14+
def sycl_event_wait(typingctx, ty_event: dpex_types.DpctlSyclEvent):
15+
sig = types.void(dpex_types.DpctlSyclEvent())
16+
17+
# defines the custom code generation
18+
def codegen(context, builder, signature, args):
19+
sycl_event_dm = default_manager.lookup(ty_event)
20+
event_ref = builder.extract_value(
21+
args[0],
22+
sycl_event_dm.get_field_position("event_ref"),
23+
)
24+
25+
sycl.dpctl_event_wait(builder, event_ref)
26+
27+
return sig, codegen
28+
29+
30+
@overload_method(dpex_types.DpctlSyclEvent, "wait")
31+
def ol_dpctl_sycl_event_wait(
32+
event,
33+
):
34+
"""Implementation of an overload to support dpctl.SyclEvent() inside
35+
a dpjit function.
36+
"""
37+
return lambda event: sycl_event_wait(event)
38+
39+
40+
# We don't want user to call sycl_event_wait(event), instead it must be called
41+
# with event.wait(). In that way we guarantee the argument type by the
42+
# @overload_method.
43+
__all__ = []

numba_dpex/tests/core/types/DpctlSyclEvent/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_model_for_DpctlSyclEvent():
1717
"""Test the data model for DpctlSyclEvent that is registered with numba's
1818
default data model manager.
1919
"""
20-
sycl_event = DpctlSyclEvent(dpctl.SyclEvent())
20+
sycl_event = DpctlSyclEvent()
2121
default_model = default_manager.lookup(sycl_event)
2222
assert isinstance(default_model, SyclEventModel)
2323

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import dpctl
2+
3+
from numba_dpex import dpjit
4+
5+
6+
@dpjit
7+
def wait_call(a):
8+
a.wait()
9+
return None
10+
11+
12+
def test_wait_DpctlSyclEvent():
13+
"""Test the dpctl.SyclEvent.wait() call overload."""
14+
15+
e = dpctl.SyclEvent()
16+
wait_call(e)

0 commit comments

Comments
 (0)