Skip to content

Commit 3cfa80b

Browse files
committed
Add dependant event support to async kernel submission
1 parent 070df7d commit 3cfa80b

File tree

5 files changed

+133
-24
lines changed

5 files changed

+133
-24
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _submit_parfor_kernel(
185185
kl_builder.set_arguments(
186186
kernel_fn.kernel_arg_types, kernel_args=kernel_args
187187
)
188-
kl_builder.set_dependant_event_list([])
188+
kl_builder.set_dependent_events([])
189189
event_ref = kl_builder.submit()
190190

191191
sycl.dpctl_event_wait(lowerer.builder, event_ref)

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -668,22 +668,42 @@ def set_arguments_form_tuple(
668668
kernel_args = self._extract_llvm_values_from_tuple(ll_kernel_args_tuple)
669669
self.set_arguments(ty_kernel_args_tuple, kernel_args)
670670

671-
def set_dependant_event_list(self, dep_events: list[llvmir.Instruction]):
672-
"""Sets dependant events to the argument list."""
673-
if self.arguments.dep_events is not None:
674-
return
671+
def set_dependent_events(self, dep_events: list[llvmir.Instruction]):
672+
"""Sets dependent events to the argument list."""
673+
ll_dep_events = self._create_ll_from_py_list(types.voidptr, dep_events)
674+
self.arguments.dep_events = ll_dep_events
675+
self.arguments.dep_events_len = self.context.get_constant(
676+
types.uintp, len(dep_events)
677+
)
675678

676-
if len(dep_events) > 0:
677-
# TODO: implement for non zero input
678-
raise NotImplementedError
679+
def set_dependent_events_from_tuple(
680+
self,
681+
ty_dependent_events: UniTuple,
682+
ll_dependent_events: llvmir.Instruction,
683+
):
684+
"""Set's dependent events from tuple represented by LLVM IR.
679685
680-
self.arguments.dep_events = self.builder.bitcast(
681-
utils.create_null_ptr(builder=self.builder, context=self.context),
682-
utils.get_llvm_type(context=self.context, type=types.voidptr),
683-
)
684-
self.arguments.dep_events_len = self.context.get_constant(
685-
types.uintp, 0
686+
Args:
687+
ll_dependent_events: tuple of numba's data models.
688+
"""
689+
if len(ty_dependent_events) == 0:
690+
self.set_dependent_events([])
691+
return
692+
693+
ty_event = ty_dependent_events[0]
694+
dm_dependent_events = self._extract_llvm_values_from_tuple(
695+
ll_dependent_events
686696
)
697+
dependent_events = []
698+
for dm_dependent_event in dm_dependent_events:
699+
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
700+
self.context,
701+
self.builder,
702+
value=dm_dependent_event,
703+
)
704+
dependent_events.append(event_struct_proxy.event_ref)
705+
706+
self.set_dependent_events(dependent_events)
687707

688708
def submit(self) -> llvmir.Instruction:
689709
"""Submits kernel by calling sycl.dpctl_queue_submit_range or

numba_dpex/dpctl_iface/libsyclinterface_bindings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def dpctl_queue_submit_range(builder: llvmir.IRBuilder, *args):
154154
llvmir.IntType(64),
155155
llvmir.IntType(64).as_pointer(),
156156
llvmir.IntType(64),
157-
cgutils.voidptr_t,
157+
cgutils.voidptr_t.as_pointer(),
158158
llvmir.IntType(64),
159159
],
160160
func_name="DPCTLQueue_SubmitRange",
@@ -195,7 +195,7 @@ def dpctl_queue_submit_ndrange(builder: llvmir.IRBuilder, *args):
195195
llvmir.IntType(64).as_pointer(),
196196
llvmir.IntType(64).as_pointer(),
197197
llvmir.IntType(64),
198-
cgutils.voidptr_t,
198+
cgutils.voidptr_t.as_pointer(),
199199
llvmir.IntType(64),
200200
],
201201
func_name="DPCTLQueue_SubmitNDRange",

numba_dpex/experimental/launcher.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from llvmlite import ir as llvmir
1313
from numba.core import cgutils, types
1414
from numba.core.cpu import CPUContext
15-
from numba.core.types.containers import UniTuple
15+
from numba.core.types.containers import Tuple, UniTuple
1616
from numba.core.types.functions import Dispatcher
1717
from numba.extending import intrinsic
1818

@@ -51,13 +51,15 @@ def _submit_kernel_async(
5151
typingctx,
5252
ty_kernel_fn: Dispatcher,
5353
ty_index_space: Union[RangeType, NdRangeType],
54+
ty_dependent_events: UniTuple,
5455
ty_kernel_args_tuple: UniTuple,
5556
):
5657
"""Generates IR code for call_kernel_async dpjit function."""
5758
return _submit_kernel(
5859
typingctx,
5960
ty_kernel_fn,
6061
ty_index_space,
62+
ty_dependent_events,
6163
ty_kernel_args_tuple,
6264
sync=False,
6365
)
@@ -75,15 +77,17 @@ def _submit_kernel_sync(
7577
typingctx,
7678
ty_kernel_fn,
7779
ty_index_space,
80+
None,
7881
ty_kernel_args_tuple,
7982
sync=True,
8083
)
8184

8285

83-
def _submit_kernel(
84-
typingctx, # pylint: disable=W0613
86+
def _submit_kernel( # pylint: disable=too-many-arguments
87+
typingctx, # pylint: disable=unused-argument
8588
ty_kernel_fn: Dispatcher,
8689
ty_index_space: Union[RangeType, NdRangeType],
90+
ty_dependent_events: UniTuple,
8791
ty_kernel_args_tuple: UniTuple,
8892
sync: bool,
8993
):
@@ -106,7 +110,21 @@ def _submit_kernel(
106110
ty_event = DpctlSyclEvent()
107111
ty_return = types.Tuple([ty_event, ty_event])
108112

109-
sig = ty_return(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple)
113+
if ty_dependent_events is not None:
114+
if not isinstance(ty_dependent_events, UniTuple) and not isinstance(
115+
ty_dependent_events, Tuple
116+
):
117+
raise ValueError("dependent events must be passed as a tuple")
118+
119+
sig = ty_return(
120+
ty_kernel_fn,
121+
ty_index_space,
122+
ty_dependent_events,
123+
ty_kernel_args_tuple,
124+
)
125+
else:
126+
sig = ty_return(ty_kernel_fn, ty_index_space, ty_kernel_args_tuple)
127+
110128
kernel_sig = types.void(*ty_kernel_args_tuple)
111129
# ty_kernel_fn is type specific to exact function, so we can get function
112130
# directly from type and compile it. Thats why we don't need to get it in
@@ -123,8 +141,14 @@ def codegen(
123141
):
124142
ty_index_space: Union[RangeType, NdRangeType] = sig.args[1]
125143
ll_index_space: llvmir.Instruction = llargs[1]
126-
ty_kernel_args_tuple: UniTuple = sig.args[2]
127-
ll_kernel_args_tuple: llvmir.Instruction = llargs[2]
144+
ty_kernel_args_tuple: UniTuple = sig.args[-1]
145+
ll_kernel_args_tuple: llvmir.Instruction = llargs[-1]
146+
147+
if len(llargs) == 4:
148+
ty_dependent_events: UniTuple = sig.args[2]
149+
ll_dependent_events: llvmir.Instruction = llargs[2]
150+
else:
151+
ty_dependent_events = None
128152

129153
kl_builder = kl.KernelLaunchIRBuilder(
130154
cgctx,
@@ -140,7 +164,13 @@ def codegen(
140164
)
141165
kl_builder.set_queue_from_arguments()
142166
kl_builder.set_kernel_from_spirv(kernel_module)
143-
kl_builder.set_dependant_event_list([])
167+
if ty_dependent_events is None:
168+
kl_builder.set_dependent_events([])
169+
else:
170+
kl_builder.set_dependent_events_from_tuple(
171+
ty_dependent_events,
172+
ll_dependent_events,
173+
)
144174
device_event_ref = kl_builder.submit()
145175

146176
if not sync:
@@ -185,7 +215,10 @@ def call_kernel(kernel_fn, index_space, *kernel_args) -> None:
185215

186216
@dpjit
187217
def call_kernel_async(
188-
kernel_fn, index_space, *kernel_args
218+
kernel_fn,
219+
index_space,
220+
dependent_events: list[dpctl.SyclEvent],
221+
*kernel_args
189222
) -> tuple[dpctl.SyclEvent, dpctl.SyclEvent]:
190223
"""Calls a numba_dpex.kernel decorated function from CPython or from another
191224
dpjit function. Kernel execution happens in asyncronous way, so the thread
@@ -210,5 +243,6 @@ def call_kernel_async(
210243
return _submit_kernel_async( # pylint: disable=E1120
211244
kernel_fn,
212245
index_space,
246+
dependent_events,
213247
kernel_args,
214248
)

numba_dpex/tests/experimental/test_async_kernel.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import dpctl
66
import dpnp
7+
import pytest
8+
from numba.core.errors import TypingError
79

810
import numba_dpex as dpex
911
import numba_dpex.experimental as exp_dpex
@@ -33,6 +35,7 @@ def test_async_add():
3335
host_ref, event_ref = exp_dpex.call_kernel_async(
3436
add,
3537
r,
38+
(),
3639
a,
3740
b,
3841
c,
@@ -50,6 +53,58 @@ def test_async_add():
5053
assert dpnp.array_equal(c, d)
5154

5255

56+
def test_async_dependent_add_list_exception():
57+
"""Checks either ValueError is triggered if list was passed instead of
58+
tuple for the dependent events."""
59+
size = 10
60+
61+
# TODO: should capture ValueError, but numba captures it and generates
62+
# TypingError. ValueError is still readable there.
63+
with pytest.raises(TypingError):
64+
exp_dpex.call_kernel_async(
65+
add,
66+
Range(size),
67+
[dpctl.SyclEvent()],
68+
dpnp.ones(size),
69+
dpnp.ones(size),
70+
dpnp.ones(size),
71+
)
72+
73+
74+
def test_async_dependent_add():
75+
size = 10
76+
a = dpnp.ones(size)
77+
b = dpnp.ones(size)
78+
c = dpnp.zeros(size)
79+
80+
r = Range(size)
81+
82+
host_ref, event_ref = exp_dpex.call_kernel_async(
83+
add,
84+
r,
85+
(),
86+
a,
87+
b,
88+
c,
89+
)
90+
91+
host2_ref, event2_ref = exp_dpex.call_kernel_async(
92+
add,
93+
r,
94+
(event_ref,),
95+
a,
96+
c,
97+
b,
98+
)
99+
100+
event2_ref.wait()
101+
d = dpnp.ones(size) * 3
102+
assert dpnp.array_equal(b, d)
103+
104+
host_ref.wait()
105+
host2_ref.wait()
106+
107+
53108
def test_async_add_from_cache():
54109
test_async_add() # compile
55110
old_size = testing.kernel_cache_size()

0 commit comments

Comments
 (0)