13
13
)
14
14
15
15
from numba_dpex import config
16
- from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
17
16
from numba_dpex .core .parfors .reduction_helper import (
18
17
ReductionHelper ,
19
18
ReductionKernelVariables ,
20
19
)
21
20
from numba_dpex .core .utils .kernel_launcher import KernelLaunchIRBuilder
21
+ from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
22
+ from numba_dpex .core .datamodel .models import (
23
+ dpex_data_model_manager as kernel_dmm ,
24
+ )
22
25
23
26
from ..exceptions import UnsupportedParforError
24
27
from ..types .dpnp_ndarray_type import DpnpNdArray
28
31
create_reduction_remainder_kernel_for_parfor ,
29
32
)
30
33
31
- _KernelArgs = namedtuple (
32
- "_KernelArgs" ,
33
- ["num_flattened_args" , "arg_vals" , "arg_types" ],
34
- )
35
-
36
34
37
35
# A global list of kernels to keep the objects alive indefinitely.
38
36
keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
68
66
var_val = lowerer .varmap [x ]
69
67
70
68
if var_val :
71
- if not isinstance (var_val .type , llvmir .PointerType ):
72
- with lowerer .builder .goto_entry_block ():
73
- var_val_ptr = lowerer .builder .alloca (var_val .type )
74
- lowerer .builder .store (var_val , var_val_ptr )
75
- return var_val_ptr
69
+ if isinstance (var_val .type , llvmir .PointerType ):
70
+ return lowerer .builder .load (var_val )
76
71
else :
77
72
return var_val
78
73
else :
@@ -91,56 +86,6 @@ class ParforLowerImpl:
91
86
for a parfor and submits it to a queue.
92
87
"""
93
88
94
- def _build_kernel_arglist (
95
- self , kernel_fn , lowerer , kernel_builder : KernelLaunchIRBuilder
96
- ):
97
- """Creates local variables for all the arguments and the argument types
98
- that are passes to the kernel function.
99
-
100
- Args:
101
- kernel_fn: Kernel function to be launched.
102
- lowerer: The Numba lowerer used to generate the LLVM IR
103
-
104
- Raises:
105
- AssertionError: If the LLVM IR Value for an argument defined in
106
- Numba IR is not found.
107
- """
108
- num_flattened_args = 0
109
-
110
- # Compute number of args to be passed to the kernel. Note that the
111
- # actual number of kernel arguments is greater than the count of
112
- # kernel_fn.kernel_args as arrays get flattened.
113
- for arg_type in kernel_fn .kernel_arg_types :
114
- if isinstance (arg_type , DpnpNdArray ):
115
- datamodel = dpex_dmm .lookup (arg_type )
116
- num_flattened_args += datamodel .flattened_field_count
117
- elif arg_type == types .complex64 or arg_type == types .complex128 :
118
- num_flattened_args += 2
119
- else :
120
- num_flattened_args += 1
121
-
122
- # Create LLVM values for the kernel args list and kernel arg types list
123
- args_list = kernel_builder .allocate_kernel_arg_array (num_flattened_args )
124
- args_ty_list = kernel_builder .allocate_kernel_arg_ty_array (
125
- num_flattened_args
126
- )
127
- callargs_ptrs = []
128
- for arg in kernel_fn .kernel_args :
129
- callargs_ptrs .append (_getvar (lowerer , arg ))
130
-
131
- kernel_builder .populate_kernel_args_and_args_ty_arrays (
132
- kernel_argtys = kernel_fn .kernel_arg_types ,
133
- callargs_ptrs = callargs_ptrs ,
134
- args_list = args_list ,
135
- args_ty_list = args_ty_list ,
136
- )
137
-
138
- return _KernelArgs (
139
- num_flattened_args = num_flattened_args ,
140
- arg_vals = args_list ,
141
- arg_types = args_ty_list ,
142
- )
143
-
144
89
def _loop_ranges (
145
90
self ,
146
91
lowerer ,
@@ -163,7 +108,10 @@ def _loop_ranges(
163
108
"non-unit strides are not yet supported."
164
109
)
165
110
global_range .append (stop )
166
-
111
+ # For now the local_range is always an empty list as numba_dpex always
112
+ # submits kernels generated for parfor nodes as range kernels.
113
+ # The provision is kept here in case a future change adds the ability
114
+ # to submit these kernels as ndrange kernels.
167
115
local_range = []
168
116
169
117
return global_range , local_range
@@ -215,31 +163,34 @@ def _submit_parfor_kernel(
215
163
# Ensure that the Python arguments are kept alive for the duration of
216
164
# the kernel execution
217
165
keep_alive_kernels .append (kernel_fn .kernel )
218
- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
166
+ kl_builder = KernelLaunchIRBuilder (
167
+ lowerer .context , lowerer .builder , kernel_dmm
168
+ )
169
+
170
+ queue_ref = kl_builder .get_queue (exec_queue = kernel_fn .queue )
219
171
220
- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
221
- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
172
+ kernel_args = []
173
+ for arg in kernel_fn .kernel_args :
174
+ kernel_args .append (_getvar (lowerer , arg ))
222
175
223
176
kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
224
177
kernel_ref = lowerer .builder .inttoptr (
225
178
lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
226
179
cgutils .voidptr_t ,
227
180
)
228
- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
229
-
230
- # Submit a synchronous kernel
231
- kernel_builder .submit_sycl_kernel (
232
- sycl_kernel_ref = kernel_ref ,
233
- sycl_queue_ref = curr_queue_ref ,
234
- total_kernel_args = args .num_flattened_args ,
235
- arg_list = args .arg_vals ,
236
- arg_ty_list = args .arg_types ,
237
- global_range = global_range ,
238
- local_range = local_range ,
181
+
182
+ kl_builder .set_kernel (kernel_ref )
183
+ kl_builder .set_queue (queue_ref )
184
+ kl_builder .set_range (global_range , local_range )
185
+ kl_builder .set_arguments (
186
+ kernel_fn .kernel_arg_types , kernel_args = kernel_args
239
187
)
188
+ kl_builder .set_dependant_event_list (dep_events = [])
189
+ event_ref = kl_builder .submit ()
240
190
241
- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
242
- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
191
+ sycl .dpctl_event_wait (lowerer .builder , event_ref )
192
+ sycl .dpctl_event_delete (lowerer .builder , event_ref )
193
+ sycl .dpctl_queue_delete (lowerer .builder , queue_ref )
243
194
244
195
def _reduction_codegen (
245
196
self ,
0 commit comments