13
13
)
14
14
15
15
from numba_dpex import config
16
- from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
17
16
from numba_dpex .core .parfors .reduction_helper import (
18
17
ReductionHelper ,
19
18
ReductionKernelVariables ,
20
19
)
21
20
from numba_dpex .core .utils .kernel_launcher import KernelLaunchIRBuilder
21
+ from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
22
+ from numba_dpex .core .datamodel .models import (
23
+ dpex_data_model_manager as kernel_dmm ,
24
+ )
22
25
23
26
from ..exceptions import UnsupportedParforError
24
27
from ..types .dpnp_ndarray_type import DpnpNdArray
25
- from .kernel_builder import create_kernel_for_parfor
28
+ from .kernel_builder import ParforKernel , create_kernel_for_parfor
26
29
from .reduction_kernel_builder import (
27
30
create_reduction_main_kernel_for_parfor ,
28
31
create_reduction_remainder_kernel_for_parfor ,
29
32
)
30
33
31
- _KernelArgs = namedtuple (
32
- "_KernelArgs" ,
33
- ["num_flattened_args" , "arg_vals" , "arg_types" ],
34
- )
35
-
36
34
37
35
# A global list of kernels to keep the objects alive indefinitely.
38
36
keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
68
66
var_val = lowerer .varmap [x ]
69
67
70
68
if var_val :
71
- if not isinstance (var_val .type , llvmir .PointerType ):
72
- with lowerer .builder .goto_entry_block ():
73
- var_val_ptr = lowerer .builder .alloca (var_val .type )
74
- lowerer .builder .store (var_val , var_val_ptr )
75
- return var_val_ptr
69
+ if isinstance (var_val .type , llvmir .PointerType ):
70
+ return lowerer .builder .load (var_val )
76
71
else :
77
72
return var_val
78
73
else :
@@ -91,75 +86,15 @@ class ParforLowerImpl:
91
86
for a parfor and submits it to a queue.
92
87
"""
93
88
94
- def _build_kernel_arglist (self , kernel_fn , lowerer , kernel_builder ):
95
- """Creates local variables for all the arguments and the argument types
96
- that are passes to the kernel function.
97
-
98
- Args:
99
- kernel_fn: Kernel function to be launched.
100
- lowerer: The Numba lowerer used to generate the LLVM IR
101
-
102
- Raises:
103
- AssertionError: If the LLVM IR Value for an argument defined in
104
- Numba IR is not found.
105
- """
106
- num_flattened_args = 0
107
-
108
- # Compute number of args to be passed to the kernel. Note that the
109
- # actual number of kernel arguments is greater than the count of
110
- # kernel_fn.kernel_args as arrays get flattened.
111
- for arg_type in kernel_fn .kernel_arg_types :
112
- if isinstance (arg_type , DpnpNdArray ):
113
- datamodel = dpex_dmm .lookup (arg_type )
114
- num_flattened_args += datamodel .flattened_field_count
115
- elif arg_type == types .complex64 or arg_type == types .complex128 :
116
- num_flattened_args += 2
117
- else :
118
- num_flattened_args += 1
119
-
120
- # Create LLVM values for the kernel args list and kernel arg types list
121
- args_list = kernel_builder .allocate_kernel_arg_array (num_flattened_args )
122
- args_ty_list = kernel_builder .allocate_kernel_arg_ty_array (
123
- num_flattened_args
124
- )
125
- callargs_ptrs = []
126
- for arg in kernel_fn .kernel_args :
127
- callargs_ptrs .append (_getvar (lowerer , arg ))
128
-
129
- kernel_builder .populate_kernel_args_and_args_ty_arrays (
130
- kernel_argtys = kernel_fn .kernel_arg_types ,
131
- callargs_ptrs = callargs_ptrs ,
132
- args_list = args_list ,
133
- args_ty_list = args_ty_list ,
134
- )
135
-
136
- return _KernelArgs (
137
- num_flattened_args = num_flattened_args ,
138
- arg_vals = args_list ,
139
- arg_types = args_ty_list ,
140
- )
141
-
142
- def _submit_parfor_kernel (
89
+ def _loop_ranges (
143
90
self ,
144
91
lowerer ,
145
- kernel_fn ,
146
92
loop_ranges ,
147
93
):
148
- """
149
- Adds a call to submit a kernel function into the function body of the
150
- current Numba JIT compiled function.
151
- """
152
- # Ensure that the Python arguments are kept alive for the duration of
153
- # the kernel execution
154
- keep_alive_kernels .append (kernel_fn .kernel )
155
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
156
-
157
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
158
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
159
-
160
94
# Create a global range over which to submit the kernel based on the
161
95
# loop_ranges of the parfor
162
96
global_range = []
97
+
163
98
# SYCL ranges can have at max 3 dimension. If the parfor is of a higher
164
99
# dimension then the indexing for the higher dimensions is done inside
165
100
# the kernel.
@@ -173,48 +108,19 @@ def _submit_parfor_kernel(
173
108
"non-unit strides are not yet supported."
174
109
)
175
110
global_range .append (stop )
176
-
111
+ # For now the local_range is always an empty list as numba_dpex always
112
+ # submits kernels generated for parfor nodes as range kernels.
113
+ # The provision is kept here if in future there is newer functionality
114
+ # to submit these kernels as ndrange.
177
115
local_range = []
178
116
179
- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
180
- kernel_ref = lowerer .builder .inttoptr (
181
- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
182
- cgutils .voidptr_t ,
183
- )
184
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
185
-
186
- # Submit a synchronous kernel
187
- kernel_builder .submit_sycl_kernel (
188
- sycl_kernel_ref = kernel_ref ,
189
- sycl_queue_ref = curr_queue_ref ,
190
- total_kernel_args = args .num_flattened_args ,
191
- arg_list = args .arg_vals ,
192
- arg_ty_list = args .arg_types ,
193
- global_range = global_range ,
194
- local_range = local_range ,
195
- )
196
-
197
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
198
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
117
+ return global_range , local_range
199
118
200
- def _submit_reduction_main_parfor_kernel (
119
+ def _reduction_ranges (
201
120
self ,
202
121
lowerer ,
203
- kernel_fn ,
204
122
reductionHelper = None ,
205
123
):
206
- """
207
- Adds a call to submit the main kernel of a parfor reduction into the
208
- function body of the current Numba JIT compiled function.
209
- """
210
- # Ensure that the Python arguments are kept alive for the duration of
211
- # the kernel execution
212
- keep_alive_kernels .append (kernel_fn .kernel )
213
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
214
-
215
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
216
-
217
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
218
124
# Create a global range over which to submit the kernel based on the
219
125
# loop_ranges of the parfor
220
126
global_range = []
@@ -228,75 +134,63 @@ def _submit_reduction_main_parfor_kernel(
228
134
_load_range (lowerer , reductionHelper .work_group_size )
229
135
)
230
136
231
- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
232
- kernel_ref = lowerer .builder .inttoptr (
233
- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
234
- cgutils .voidptr_t ,
235
- )
236
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
237
-
238
- # Submit a synchronous kernel
239
- kernel_builder .submit_sycl_kernel (
240
- sycl_kernel_ref = kernel_ref ,
241
- sycl_queue_ref = curr_queue_ref ,
242
- total_kernel_args = args .num_flattened_args ,
243
- arg_list = args .arg_vals ,
244
- arg_ty_list = args .arg_types ,
245
- global_range = global_range ,
246
- local_range = local_range ,
247
- )
137
+ return global_range , local_range
248
138
249
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
250
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
139
+ def _remainder_ranges (self , lowerer ):
140
+ # Create a global range over which to submit the kernel based on the
141
+ # loop_ranges of the parfor
142
+ global_range = []
251
143
252
- def _submit_reduction_remainder_parfor_kernel (
144
+ stop = _load_range (lowerer , 1 )
145
+
146
+ global_range .append (stop )
147
+
148
+ local_range = []
149
+
150
+ return global_range , local_range
151
+
152
+ def _submit_parfor_kernel (
253
153
self ,
254
154
lowerer ,
255
- kernel_fn ,
155
+ kernel_fn : ParforKernel ,
156
+ global_range ,
157
+ local_range ,
256
158
):
257
159
"""
258
- Adds a call to submit the remainder kernel of a parfor reduction into
259
- the function body of the current Numba JIT compiled function.
160
+ Adds a call to submit a kernel function into the function body of the
161
+ current Numba JIT compiled function.
260
162
"""
261
163
# Ensure that the Python arguments are kept alive for the duration of
262
164
# the kernel execution
263
165
keep_alive_kernels .append (kernel_fn .kernel )
166
+ kl_builder = KernelLaunchIRBuilder (
167
+ lowerer .context , lowerer .builder , kernel_dmm
168
+ )
264
169
265
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
266
-
267
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
268
-
269
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
270
- # Create a global range over which to submit the kernel based on the
271
- # loop_ranges of the parfor
272
- global_range = []
273
-
274
- stop = _load_range (lowerer , 1 )
170
+ queue_ref = kl_builder .get_queue (exec_queue = kernel_fn .queue )
275
171
276
- global_range . append ( stop )
277
-
278
- local_range = []
172
+ kernel_args = []
173
+ for arg in kernel_fn . kernel_args :
174
+ kernel_args . append ( _getvar ( lowerer , arg ))
279
175
280
176
kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
281
177
kernel_ref = lowerer .builder .inttoptr (
282
178
lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
283
179
cgutils .voidptr_t ,
284
180
)
285
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
286
-
287
- # Submit a synchronous kernel
288
- kernel_builder .submit_sycl_kernel (
289
- sycl_kernel_ref = kernel_ref ,
290
- sycl_queue_ref = curr_queue_ref ,
291
- total_kernel_args = args .num_flattened_args ,
292
- arg_list = args .arg_vals ,
293
- arg_ty_list = args .arg_types ,
294
- global_range = global_range ,
295
- local_range = local_range ,
181
+
182
+ kl_builder .set_kernel (kernel_ref )
183
+ kl_builder .set_queue (queue_ref )
184
+ kl_builder .set_range (global_range , local_range )
185
+ kl_builder .set_arguments (
186
+ kernel_fn .kernel_arg_types , kernel_args = kernel_args
296
187
)
188
+ kl_builder .set_dependant_event_list (dep_events = [])
189
+ event_ref = kl_builder .submit ()
297
190
298
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
299
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
191
+ sycl .dpctl_event_wait (lowerer .builder , event_ref )
192
+ sycl .dpctl_event_delete (lowerer .builder , event_ref )
193
+ sycl .dpctl_queue_delete (lowerer .builder , queue_ref )
300
194
301
195
def _reduction_codegen (
302
196
self ,
@@ -360,10 +254,15 @@ def _reduction_codegen(
360
254
parfor_reddict ,
361
255
)
362
256
363
- self ._submit_reduction_main_parfor_kernel (
257
+ global_range , local_range = self ._reduction_ranges (
258
+ lowerer , reductionHelperList [0 ]
259
+ )
260
+
261
+ self ._submit_parfor_kernel (
364
262
lowerer ,
365
263
parfor_kernel ,
366
- reductionHelperList [0 ],
264
+ global_range ,
265
+ local_range ,
367
266
)
368
267
369
268
parfor_kernel = create_reduction_remainder_kernel_for_parfor (
@@ -376,9 +275,13 @@ def _reduction_codegen(
376
275
reductionHelperList ,
377
276
)
378
277
379
- self ._submit_reduction_remainder_parfor_kernel (
278
+ global_range , local_range = self ._remainder_ranges (lowerer )
279
+
280
+ self ._submit_parfor_kernel (
380
281
lowerer ,
381
282
parfor_kernel ,
283
+ global_range ,
284
+ local_range ,
382
285
)
383
286
384
287
reductionKernelVar .copy_final_sum_to_host (parfor_kernel )
@@ -492,11 +395,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492
395
# FIXME: Make the exception more informative
493
396
raise UnsupportedParforError
494
397
398
+ global_range , local_range = self ._loop_ranges (lowerer , loop_ranges )
399
+
495
400
# Finally submit the kernel
496
401
self ._submit_parfor_kernel (
497
402
lowerer ,
498
403
parfor_kernel ,
499
- loop_ranges ,
404
+ global_range ,
405
+ local_range ,
500
406
)
501
407
502
408
# TODO: free the kernel at this point
0 commit comments