12
12
from llvmlite import ir as llvmir
13
13
from numba .core import cgutils , types
14
14
from numba .core .cpu import CPUContext
15
- from numba .core .types .containers import UniTuple
15
+ from numba .core .types .containers import Tuple , UniTuple
16
16
from numba .core .types .functions import Dispatcher
17
17
from numba .extending import intrinsic
18
18
@@ -51,13 +51,15 @@ def _submit_kernel_async(
51
51
typingctx ,
52
52
ty_kernel_fn : Dispatcher ,
53
53
ty_index_space : Union [RangeType , NdRangeType ],
54
+ ty_dependent_events : UniTuple ,
54
55
ty_kernel_args_tuple : UniTuple ,
55
56
):
56
57
"""Generates IR code for call_kernel_async dpjit function."""
57
58
return _submit_kernel (
58
59
typingctx ,
59
60
ty_kernel_fn ,
60
61
ty_index_space ,
62
+ ty_dependent_events ,
61
63
ty_kernel_args_tuple ,
62
64
sync = False ,
63
65
)
@@ -75,15 +77,17 @@ def _submit_kernel_sync(
75
77
typingctx ,
76
78
ty_kernel_fn ,
77
79
ty_index_space ,
80
+ None ,
78
81
ty_kernel_args_tuple ,
79
82
sync = True ,
80
83
)
81
84
82
85
83
- def _submit_kernel (
84
- typingctx , # pylint: disable=W0613
86
+ def _submit_kernel ( # pylint: disable=too-many-arguments
87
+ typingctx , # pylint: disable=unused-argument
85
88
ty_kernel_fn : Dispatcher ,
86
89
ty_index_space : Union [RangeType , NdRangeType ],
90
+ ty_dependent_events : UniTuple ,
87
91
ty_kernel_args_tuple : UniTuple ,
88
92
sync : bool ,
89
93
):
@@ -106,7 +110,21 @@ def _submit_kernel(
106
110
ty_event = DpctlSyclEvent ()
107
111
ty_return = types .Tuple ([ty_event , ty_event ])
108
112
109
- sig = ty_return (ty_kernel_fn , ty_index_space , ty_kernel_args_tuple )
113
+ if ty_dependent_events is not None :
114
+ if not isinstance (ty_dependent_events , UniTuple ) and not isinstance (
115
+ ty_dependent_events , Tuple
116
+ ):
117
+ raise ValueError ("dependent events must be passed as a tuple" )
118
+
119
+ sig = ty_return (
120
+ ty_kernel_fn ,
121
+ ty_index_space ,
122
+ ty_dependent_events ,
123
+ ty_kernel_args_tuple ,
124
+ )
125
+ else :
126
+ sig = ty_return (ty_kernel_fn , ty_index_space , ty_kernel_args_tuple )
127
+
110
128
kernel_sig = types .void (* ty_kernel_args_tuple )
111
129
# ty_kernel_fn is type specific to exact function, so we can get function
112
130
# directly from type and compile it. Thats why we don't need to get it in
@@ -123,8 +141,14 @@ def codegen(
123
141
):
124
142
ty_index_space : Union [RangeType , NdRangeType ] = sig .args [1 ]
125
143
ll_index_space : llvmir .Instruction = llargs [1 ]
126
- ty_kernel_args_tuple : UniTuple = sig .args [2 ]
127
- ll_kernel_args_tuple : llvmir .Instruction = llargs [2 ]
144
+ ty_kernel_args_tuple : UniTuple = sig .args [- 1 ]
145
+ ll_kernel_args_tuple : llvmir .Instruction = llargs [- 1 ]
146
+
147
+ if len (llargs ) == 4 :
148
+ ty_dependent_events : UniTuple = sig .args [2 ]
149
+ ll_dependent_events : llvmir .Instruction = llargs [2 ]
150
+ else :
151
+ ty_dependent_events = None
128
152
129
153
kl_builder = kl .KernelLaunchIRBuilder (
130
154
cgctx ,
@@ -140,7 +164,13 @@ def codegen(
140
164
)
141
165
kl_builder .set_queue_from_arguments ()
142
166
kl_builder .set_kernel_from_spirv (kernel_module )
143
- kl_builder .set_dependant_event_list ([])
167
+ if ty_dependent_events is None :
168
+ kl_builder .set_dependent_events ([])
169
+ else :
170
+ kl_builder .set_dependent_events_from_tuple (
171
+ ty_dependent_events ,
172
+ ll_dependent_events ,
173
+ )
144
174
device_event_ref = kl_builder .submit ()
145
175
146
176
if not sync :
@@ -185,7 +215,10 @@ def call_kernel(kernel_fn, index_space, *kernel_args) -> None:
185
215
186
216
@dpjit
187
217
def call_kernel_async (
188
- kernel_fn , index_space , * kernel_args
218
+ kernel_fn ,
219
+ index_space ,
220
+ dependent_events : list [dpctl .SyclEvent ],
221
+ * kernel_args
189
222
) -> tuple [dpctl .SyclEvent , dpctl .SyclEvent ]:
190
223
"""Calls a numba_dpex.kernel decorated function from CPython or from another
191
224
dpjit function. Kernel execution happens in asyncronous way, so the thread
@@ -210,5 +243,6 @@ def call_kernel_async(
210
243
return _submit_kernel_async ( # pylint: disable=E1120
211
244
kernel_fn ,
212
245
index_space ,
246
+ dependent_events ,
213
247
kernel_args ,
214
248
)
0 commit comments