3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
5
import copy
6
+ from collections import namedtuple
6
7
7
8
from llvmlite import ir as llvmir
8
- from numba .core import ir , types
9
+ from numba .core import cgutils , ir , types
9
10
from numba .parfors .parfor import (
10
11
find_potential_aliases_parfor ,
11
12
get_parfor_outputs ,
27
28
create_reduction_remainder_kernel_for_parfor ,
28
29
)
29
30
31
+ _KernelArgs = namedtuple (
32
+ "_KernelArgs" ,
33
+ ["num_flattened_args" , "arg_vals" , "arg_types" ],
34
+ )
35
+
36
+
30
37
# A global list of kernels to keep the objects alive indefinitely.
31
38
keep_alive_kernels = []
32
39
@@ -84,21 +91,7 @@ class ParforLowerImpl:
84
91
for a parfor and submits it to a queue.
85
92
"""
86
93
87
- def _get_exec_queue (self , kernel_fn , lowerer ):
88
- """Creates a stack variable storing the sycl queue pointer used to
89
- launch the kernel function.
90
- """
91
- self .kernel_builder = KernelLaunchIRBuilder (
92
- lowerer .context , lowerer .builder , kernel_fn .kernel .addressof_ref ()
93
- )
94
-
95
- # Create a local variable storing a pointer to a DPCTLSyclQueueRef
96
- # pointer.
97
- self .curr_queue = self .kernel_builder .get_queue (
98
- exec_queue = kernel_fn .queue
99
- )
100
-
101
- def _build_kernel_arglist (self , kernel_fn , lowerer ):
94
+ def _build_kernel_arglist (self , kernel_fn , lowerer , kernel_builder ):
102
95
"""Creates local variables for all the arguments and the argument types
103
96
that are passes to the kernel function.
104
97
@@ -110,39 +103,43 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
110
103
AssertionError: If the LLVM IR Value for an argument defined in
111
104
Numba IR is not found.
112
105
"""
113
- self . num_flattened_args = 0
106
+ num_flattened_args = 0
114
107
115
108
# Compute number of args to be passed to the kernel. Note that the
116
109
# actual number of kernel arguments is greater than the count of
117
110
# kernel_fn.kernel_args as arrays get flattened.
118
111
for arg_type in kernel_fn .kernel_arg_types :
119
112
if isinstance (arg_type , DpnpNdArray ):
120
113
datamodel = dpex_dmm .lookup (arg_type )
121
- self . num_flattened_args += datamodel .flattened_field_count
114
+ num_flattened_args += datamodel .flattened_field_count
122
115
elif arg_type == types .complex64 or arg_type == types .complex128 :
123
- self . num_flattened_args += 2
116
+ num_flattened_args += 2
124
117
else :
125
- self . num_flattened_args += 1
118
+ num_flattened_args += 1
126
119
127
120
# Create LLVM values for the kernel args list and kernel arg types list
128
- self .args_list = self .kernel_builder .allocate_kernel_arg_array (
129
- self .num_flattened_args
130
- )
131
- self .args_ty_list = self .kernel_builder .allocate_kernel_arg_ty_array (
132
- self .num_flattened_args
121
+ args_list = kernel_builder .allocate_kernel_arg_array (num_flattened_args )
122
+ args_ty_list = kernel_builder .allocate_kernel_arg_ty_array (
123
+ num_flattened_args
133
124
)
134
125
callargs_ptrs = []
135
126
for arg in kernel_fn .kernel_args :
136
127
callargs_ptrs .append (_getvar (lowerer , arg ))
137
128
138
- self . kernel_builder .populate_kernel_args_and_args_ty_arrays (
129
+ kernel_builder .populate_kernel_args_and_args_ty_arrays (
139
130
kernel_argtys = kernel_fn .kernel_arg_types ,
140
131
callargs_ptrs = callargs_ptrs ,
141
- args_list = self . args_list ,
142
- args_ty_list = self . args_ty_list ,
132
+ args_list = args_list ,
133
+ args_ty_list = args_ty_list ,
143
134
datamodel_mgr = dpex_dmm ,
144
135
)
145
136
137
+ return _KernelArgs (
138
+ num_flattened_args = num_flattened_args ,
139
+ arg_vals = args_list ,
140
+ arg_types = args_ty_list ,
141
+ )
142
+
146
143
def _submit_parfor_kernel (
147
144
self ,
148
145
lowerer ,
@@ -156,9 +153,11 @@ def _submit_parfor_kernel(
156
153
# Ensure that the Python arguments are kept alive for the duration of
157
154
# the kernel execution
158
155
keep_alive_kernels .append (kernel_fn .kernel )
156
+ kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
157
+
158
+ ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
159
+ args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
159
160
160
- self ._get_exec_queue (kernel_fn , lowerer )
161
- self ._build_kernel_arglist (kernel_fn , lowerer )
162
161
# Create a global range over which to submit the kernel based on the
163
162
# loop_ranges of the parfor
164
163
global_range = []
@@ -178,18 +177,26 @@ def _submit_parfor_kernel(
178
177
179
178
local_range = []
180
179
180
+ kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
181
+ kernel_ref = lowerer .builder .inttoptr (
182
+ lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
183
+ cgutils .voidptr_t ,
184
+ )
185
+ curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
186
+
181
187
# Submit a synchronous kernel
182
- self .kernel_builder .submit_sync_kernel (
183
- self .curr_queue ,
184
- self .num_flattened_args ,
185
- self .args_list ,
186
- self .args_ty_list ,
187
- global_range ,
188
- local_range ,
188
+ kernel_builder .submit_sycl_kernel (
189
+ sycl_kernel_ref = kernel_ref ,
190
+ sycl_queue_ref = curr_queue_ref ,
191
+ total_kernel_args = args .num_flattened_args ,
192
+ arg_list = args .arg_vals ,
193
+ arg_ty_list = args .arg_types ,
194
+ global_range = global_range ,
195
+ local_range = local_range ,
189
196
)
190
197
191
198
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
192
- self . kernel_builder .free_queue (sycl_queue_val = self . curr_queue )
199
+ kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
193
200
194
201
def _submit_reduction_main_parfor_kernel (
195
202
self ,
@@ -204,9 +211,11 @@ def _submit_reduction_main_parfor_kernel(
204
211
# Ensure that the Python arguments are kept alive for the duration of
205
212
# the kernel execution
206
213
keep_alive_kernels .append (kernel_fn .kernel )
214
+ kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
215
+
216
+ ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
207
217
208
- self ._get_exec_queue (kernel_fn , lowerer )
209
- self ._build_kernel_arglist (kernel_fn , lowerer )
218
+ args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
210
219
# Create a global range over which to submit the kernel based on the
211
220
# loop_ranges of the parfor
212
221
global_range = []
@@ -220,16 +229,27 @@ def _submit_reduction_main_parfor_kernel(
220
229
_load_range (lowerer , reductionHelper .work_group_size )
221
230
)
222
231
232
+ kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
233
+ kernel_ref = lowerer .builder .inttoptr (
234
+ lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
235
+ cgutils .voidptr_t ,
236
+ )
237
+ curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
238
+
223
239
# Submit a synchronous kernel
224
- self .kernel_builder .submit_sync_kernel (
225
- self .curr_queue ,
226
- self .num_flattened_args ,
227
- self .args_list ,
228
- self .args_ty_list ,
229
- global_range ,
230
- local_range ,
240
+ kernel_builder .submit_sycl_kernel (
241
+ sycl_kernel_ref = kernel_ref ,
242
+ sycl_queue_ref = curr_queue_ref ,
243
+ total_kernel_args = args .num_flattened_args ,
244
+ arg_list = args .arg_vals ,
245
+ arg_ty_list = args .arg_types ,
246
+ global_range = global_range ,
247
+ local_range = local_range ,
231
248
)
232
249
250
+ # At this point we can free the DPCTLSyclQueueRef (curr_queue)
251
+ kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
252
+
233
253
def _submit_reduction_remainder_parfor_kernel (
234
254
self ,
235
255
lowerer ,
@@ -243,8 +263,11 @@ def _submit_reduction_remainder_parfor_kernel(
243
263
# the kernel execution
244
264
keep_alive_kernels .append (kernel_fn .kernel )
245
265
246
- self ._get_exec_queue (kernel_fn , lowerer )
247
- self ._build_kernel_arglist (kernel_fn , lowerer )
266
+ kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
267
+
268
+ ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
269
+
270
+ args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
248
271
# Create a global range over which to submit the kernel based on the
249
272
# loop_ranges of the parfor
250
273
global_range = []
@@ -255,16 +278,27 @@ def _submit_reduction_remainder_parfor_kernel(
255
278
256
279
local_range = []
257
280
281
+ kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
282
+ kernel_ref = lowerer .builder .inttoptr (
283
+ lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
284
+ cgutils .voidptr_t ,
285
+ )
286
+ curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
287
+
258
288
# Submit a synchronous kernel
259
- self .kernel_builder .submit_sync_kernel (
260
- self .curr_queue ,
261
- self .num_flattened_args ,
262
- self .args_list ,
263
- self .args_ty_list ,
264
- global_range ,
265
- local_range ,
289
+ kernel_builder .submit_sycl_kernel (
290
+ sycl_kernel_ref = kernel_ref ,
291
+ sycl_queue_ref = curr_queue_ref ,
292
+ total_kernel_args = args .num_flattened_args ,
293
+ arg_list = args .arg_vals ,
294
+ arg_ty_list = args .arg_types ,
295
+ global_range = global_range ,
296
+ local_range = local_range ,
266
297
)
267
298
299
+ # At this point we can free the DPCTLSyclQueueRef (curr_queue)
300
+ kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
301
+
268
302
def _reduction_codegen (
269
303
self ,
270
304
parfor ,
0 commit comments