12
12
)
13
13
14
14
from numba_dpex import config
15
+ from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
15
16
from numba_dpex .core .parfors .reduction_helper import (
16
17
ReductionHelper ,
17
18
ReductionKernelVariables ,
26
27
create_reduction_remainder_kernel_for_parfor ,
27
28
)
28
29
29
- from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
30
-
31
30
# A global list of kernels to keep the objects alive indefinitely.
32
31
keep_alive_kernels = []
33
32
@@ -89,7 +88,9 @@ def _get_exec_queue(self, kernel_fn, lowerer):
89
88
"""Creates a stack variable storing the sycl queue pointer used to
90
89
launch the kernel function.
91
90
"""
92
- self .kernel_builder = KernelLaunchIRBuilder (lowerer , kernel_fn .kernel )
91
+ self .kernel_builder = KernelLaunchIRBuilder (
92
+ lowerer .context , lowerer .builder , kernel_fn .kernel .addressof_ref ()
93
+ )
93
94
94
95
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
95
96
# pointer.
@@ -109,71 +110,38 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
109
110
AssertionError: If the LLVM IR Value for an argument defined in
110
111
Numba IR is not found.
111
112
"""
112
- num_flattened_args = 0
113
+ self . num_flattened_args = 0
113
114
114
115
# Compute number of args to be passed to the kernel. Note that the
115
116
# actual number of kernel arguments is greater than the count of
116
117
# kernel_fn.kernel_args as arrays get flattened.
117
118
for arg_type in kernel_fn .kernel_arg_types :
118
119
if isinstance (arg_type , DpnpNdArray ):
119
120
datamodel = dpex_dmm .lookup (arg_type )
120
- num_flattened_args += datamodel .flattened_field_count
121
+ self . num_flattened_args += datamodel .flattened_field_count
121
122
elif arg_type == types .complex64 or arg_type == types .complex128 :
122
- num_flattened_args += 2
123
+ self . num_flattened_args += 2
123
124
else :
124
- num_flattened_args += 1
125
+ self . num_flattened_args += 1
125
126
126
127
# Create LLVM values for the kernel args list and kernel arg types list
127
128
self .args_list = self .kernel_builder .allocate_kernel_arg_array (
128
- num_flattened_args
129
+ self . num_flattened_args
129
130
)
130
131
self .args_ty_list = self .kernel_builder .allocate_kernel_arg_ty_array (
131
- num_flattened_args
132
+ self .num_flattened_args
133
+ )
134
+ callargs_ptrs = []
135
+ for arg in kernel_fn .kernel_args :
136
+ callargs_ptrs .append (_getvar (lowerer , arg ))
137
+
138
+ self .kernel_builder .populate_kernel_args_and_args_ty_arrays (
139
+ kernel_argtys = kernel_fn .kernel_arg_types ,
140
+ callargs_ptrs = callargs_ptrs ,
141
+ args_list = self .args_list ,
142
+ args_ty_list = self .args_ty_list ,
143
+ datamodel_mgr = dpex_dmm ,
132
144
)
133
- # Populate the args_list and the args_ty_list LLVM arrays
134
- self .kernel_arg_num = 0
135
- for arg_num , arg in enumerate (kernel_fn .kernel_args ):
136
- argtype = kernel_fn .kernel_arg_types [arg_num ]
137
- llvm_val = _getvar (lowerer , arg )
138
- if isinstance (argtype , DpnpNdArray ):
139
- datamodel = dpex_dmm .lookup (argtype )
140
- self .kernel_builder .build_array_arg (
141
- array_val = llvm_val ,
142
- array_data_model = datamodel ,
143
- array_rank = argtype .ndim ,
144
- arg_list = self .args_list ,
145
- args_ty_list = self .args_ty_list ,
146
- arg_num = self .kernel_arg_num ,
147
- )
148
- self .kernel_arg_num += datamodel .flattened_field_count
149
- else :
150
- if argtype == types .complex64 :
151
- self .kernel_builder .build_complex_arg (
152
- llvm_val ,
153
- types .float32 ,
154
- self .args_list ,
155
- self .args_ty_list ,
156
- self .kernel_arg_num ,
157
- )
158
- self .kernel_arg_num += 2
159
- elif argtype == types .complex128 :
160
- self .kernel_builder .build_complex_arg (
161
- llvm_val ,
162
- types .float64 ,
163
- self .args_list ,
164
- self .args_ty_list ,
165
- self .kernel_arg_num ,
166
- )
167
- self .kernel_arg_num += 2
168
- else :
169
- self .kernel_builder .build_arg (
170
- llvm_val ,
171
- argtype ,
172
- self .args_list ,
173
- self .args_ty_list ,
174
- self .kernel_arg_num ,
175
- )
176
- self .kernel_arg_num += 1
177
145
178
146
def _submit_parfor_kernel (
179
147
self ,
@@ -213,7 +181,7 @@ def _submit_parfor_kernel(
213
181
# Submit a synchronous kernel
214
182
self .kernel_builder .submit_sync_kernel (
215
183
self .curr_queue ,
216
- self .kernel_arg_num ,
184
+ self .num_flattened_args ,
217
185
self .args_list ,
218
186
self .args_ty_list ,
219
187
global_range ,
@@ -255,7 +223,7 @@ def _submit_reduction_main_parfor_kernel(
255
223
# Submit a synchronous kernel
256
224
self .kernel_builder .submit_sync_kernel (
257
225
self .curr_queue ,
258
- self .kernel_arg_num ,
226
+ self .num_flattened_args ,
259
227
self .args_list ,
260
228
self .args_ty_list ,
261
229
global_range ,
@@ -290,7 +258,7 @@ def _submit_reduction_remainder_parfor_kernel(
290
258
# Submit a synchronous kernel
291
259
self .kernel_builder .submit_sync_kernel (
292
260
self .curr_queue ,
293
- self .kernel_arg_num ,
261
+ self .num_flattened_args ,
294
262
self .args_list ,
295
263
self .args_ty_list ,
296
264
global_range ,
0 commit comments