22
22
23
23
from ..exceptions import UnsupportedParforError
24
24
from ..types .dpnp_ndarray_type import DpnpNdArray
25
- from .kernel_builder import create_kernel_for_parfor
25
+ from .kernel_builder import ParforKernel , create_kernel_for_parfor
26
26
from .reduction_kernel_builder import (
27
27
create_reduction_main_kernel_for_parfor ,
28
28
create_reduction_remainder_kernel_for_parfor ,
@@ -91,7 +91,9 @@ class ParforLowerImpl:
91
91
for a parfor and submits it to a queue.
92
92
"""
93
93
94
- def _build_kernel_arglist (self , kernel_fn , lowerer , kernel_builder ):
94
+ def _build_kernel_arglist (
95
+ self , kernel_fn , lowerer , kernel_builder : KernelLaunchIRBuilder
96
+ ):
95
97
"""Creates local variables for all the arguments and the argument types
96
98
that are passes to the kernel function.
97
99
@@ -139,27 +141,15 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
139
141
arg_types = args_ty_list ,
140
142
)
141
143
142
- def _submit_parfor_kernel (
144
+ def _loop_ranges (
143
145
self ,
144
146
lowerer ,
145
- kernel_fn ,
146
147
loop_ranges ,
147
148
):
148
- """
149
- Adds a call to submit a kernel function into the function body of the
150
- current Numba JIT compiled function.
151
- """
152
- # Ensure that the Python arguments are kept alive for the duration of
153
- # the kernel execution
154
- keep_alive_kernels .append (kernel_fn .kernel )
155
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
156
-
157
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
158
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
159
-
160
149
# Create a global range over which to submit the kernel based on the
161
150
# loop_ranges of the parfor
162
151
global_range = []
152
+
163
153
# SYCL ranges can have at max 3 dimension. If the parfor is of a higher
164
154
# dimension then the indexing for the higher dimensions is done inside
165
155
# the kernel.
@@ -176,45 +166,13 @@ def _submit_parfor_kernel(
176
166
177
167
local_range = []
178
168
179
- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
180
- kernel_ref = lowerer .builder .inttoptr (
181
- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
182
- cgutils .voidptr_t ,
183
- )
184
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
185
-
186
- # Submit a synchronous kernel
187
- kernel_builder .submit_sycl_kernel (
188
- sycl_kernel_ref = kernel_ref ,
189
- sycl_queue_ref = curr_queue_ref ,
190
- total_kernel_args = args .num_flattened_args ,
191
- arg_list = args .arg_vals ,
192
- arg_ty_list = args .arg_types ,
193
- global_range = global_range ,
194
- local_range = local_range ,
195
- )
196
-
197
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
198
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
169
+ return global_range , local_range
199
170
200
- def _submit_reduction_main_parfor_kernel (
171
+ def _reduction_ranges (
201
172
self ,
202
173
lowerer ,
203
- kernel_fn ,
204
174
reductionHelper = None ,
205
175
):
206
- """
207
- Adds a call to submit the main kernel of a parfor reduction into the
208
- function body of the current Numba JIT compiled function.
209
- """
210
- # Ensure that the Python arguments are kept alive for the duration of
211
- # the kernel execution
212
- keep_alive_kernels .append (kernel_fn .kernel )
213
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
214
-
215
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
216
-
217
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
218
176
# Create a global range over which to submit the kernel based on the
219
177
# loop_ranges of the parfor
220
178
global_range = []
@@ -228,54 +186,39 @@ def _submit_reduction_main_parfor_kernel(
228
186
_load_range (lowerer , reductionHelper .work_group_size )
229
187
)
230
188
231
- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
232
- kernel_ref = lowerer .builder .inttoptr (
233
- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
234
- cgutils .voidptr_t ,
235
- )
236
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
189
+ return global_range , local_range
237
190
238
- # Submit a synchronous kernel
239
- kernel_builder .submit_sycl_kernel (
240
- sycl_kernel_ref = kernel_ref ,
241
- sycl_queue_ref = curr_queue_ref ,
242
- total_kernel_args = args .num_flattened_args ,
243
- arg_list = args .arg_vals ,
244
- arg_ty_list = args .arg_types ,
245
- global_range = global_range ,
246
- local_range = local_range ,
247
- )
191
+ def _remainder_ranges (self , lowerer ):
192
+ # Create a global range over which to submit the kernel based on the
193
+ # loop_ranges of the parfor
194
+ global_range = []
248
195
249
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
250
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
196
+ stop = _load_range (lowerer , 1 )
197
+
198
+ global_range .append (stop )
251
199
252
- def _submit_reduction_remainder_parfor_kernel (
200
+ local_range = []
201
+
202
+ return global_range , local_range
203
+
204
+ def _submit_parfor_kernel (
253
205
self ,
254
206
lowerer ,
255
- kernel_fn ,
207
+ kernel_fn : ParforKernel ,
208
+ global_range ,
209
+ local_range ,
256
210
):
257
211
"""
258
- Adds a call to submit the remainder kernel of a parfor reduction into
259
- the function body of the current Numba JIT compiled function.
212
+ Adds a call to submit a kernel function into the function body of the
213
+ current Numba JIT compiled function.
260
214
"""
261
215
# Ensure that the Python arguments are kept alive for the duration of
262
216
# the kernel execution
263
217
keep_alive_kernels .append (kernel_fn .kernel )
264
-
265
218
kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
266
219
267
220
ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
268
-
269
221
args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
270
- # Create a global range over which to submit the kernel based on the
271
- # loop_ranges of the parfor
272
- global_range = []
273
-
274
- stop = _load_range (lowerer , 1 )
275
-
276
- global_range .append (stop )
277
-
278
- local_range = []
279
222
280
223
kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
281
224
kernel_ref = lowerer .builder .inttoptr (
@@ -360,10 +303,15 @@ def _reduction_codegen(
360
303
parfor_reddict ,
361
304
)
362
305
363
- self ._submit_reduction_main_parfor_kernel (
306
+ global_range , local_range = self ._reduction_ranges (
307
+ lowerer , reductionHelperList [0 ]
308
+ )
309
+
310
+ self ._submit_parfor_kernel (
364
311
lowerer ,
365
312
parfor_kernel ,
366
- reductionHelperList [0 ],
313
+ global_range ,
314
+ local_range ,
367
315
)
368
316
369
317
parfor_kernel = create_reduction_remainder_kernel_for_parfor (
@@ -376,9 +324,13 @@ def _reduction_codegen(
376
324
reductionHelperList ,
377
325
)
378
326
379
- self ._submit_reduction_remainder_parfor_kernel (
327
+ global_range , local_range = self ._remainder_ranges (lowerer )
328
+
329
+ self ._submit_parfor_kernel (
380
330
lowerer ,
381
331
parfor_kernel ,
332
+ global_range ,
333
+ local_range ,
382
334
)
383
335
384
336
reductionKernelVar .copy_final_sum_to_host (parfor_kernel )
@@ -492,11 +444,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492
444
# FIXME: Make the exception more informative
493
445
raise UnsupportedParforError
494
446
447
+ global_range , local_range = self ._loop_ranges (lowerer , loop_ranges )
448
+
495
449
# Finally submit the kernel
496
450
self ._submit_parfor_kernel (
497
451
lowerer ,
498
452
parfor_kernel ,
499
- loop_ranges ,
453
+ global_range ,
454
+ local_range ,
500
455
)
501
456
502
457
# TODO: free the kernel at this point
0 commit comments