Skip to content

Commit e066286

Browse files
author
Diptorup Deb
committed
Change the kernel_launcher.submit_sync_kernel API.
- The helper function submit_sync_kernel was renamed to submit_sycl_kernel and can now optionally return a DpctlSyclEventRef object so that the caller can wait on it at the call site. - Other changes were made to the API of the kernel_launcher module.
1 parent 7e7ef8b commit e066286

File tree

3 files changed

+127
-86
lines changed

3 files changed

+127
-86
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 90 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import copy
66

7+
from collections import namedtuple
78
from llvmlite import ir as llvmir
8-
from numba.core import ir, types
9+
from numba.core import ir, types, cgutils
910
from numba.parfors.parfor import (
1011
find_potential_aliases_parfor,
1112
get_parfor_outputs,
@@ -27,6 +28,12 @@
2728
create_reduction_remainder_kernel_for_parfor,
2829
)
2930

31+
_KernelArgs = namedtuple(
32+
"_KernelArgs",
33+
["num_flattened_args", "arg_vals", "arg_types"],
34+
)
35+
36+
3037
# A global list of kernels to keep the objects alive indefinitely.
3138
keep_alive_kernels = []
3239

@@ -84,21 +91,7 @@ class ParforLowerImpl:
8491
for a parfor and submits it to a queue.
8592
"""
8693

87-
def _get_exec_queue(self, kernel_fn, lowerer):
88-
"""Creates a stack variable storing the sycl queue pointer used to
89-
launch the kernel function.
90-
"""
91-
self.kernel_builder = KernelLaunchIRBuilder(
92-
lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref()
93-
)
94-
95-
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
96-
# pointer.
97-
self.curr_queue = self.kernel_builder.get_queue(
98-
exec_queue=kernel_fn.queue
99-
)
100-
101-
def _build_kernel_arglist(self, kernel_fn, lowerer):
94+
def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
10295
"""Creates local variables for all the arguments and the argument types
10396
that are passed to the kernel function.
10497
@@ -110,39 +103,43 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
110103
AssertionError: If the LLVM IR Value for an argument defined in
111104
Numba IR is not found.
112105
"""
113-
self.num_flattened_args = 0
106+
num_flattened_args = 0
114107

115108
# Compute number of args to be passed to the kernel. Note that the
116109
# actual number of kernel arguments is greater than the count of
117110
# kernel_fn.kernel_args as arrays get flattened.
118111
for arg_type in kernel_fn.kernel_arg_types:
119112
if isinstance(arg_type, DpnpNdArray):
120113
datamodel = dpex_dmm.lookup(arg_type)
121-
self.num_flattened_args += datamodel.flattened_field_count
114+
num_flattened_args += datamodel.flattened_field_count
122115
elif arg_type == types.complex64 or arg_type == types.complex128:
123-
self.num_flattened_args += 2
116+
num_flattened_args += 2
124117
else:
125-
self.num_flattened_args += 1
118+
num_flattened_args += 1
126119

127120
# Create LLVM values for the kernel args list and kernel arg types list
128-
self.args_list = self.kernel_builder.allocate_kernel_arg_array(
129-
self.num_flattened_args
130-
)
131-
self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array(
132-
self.num_flattened_args
121+
args_list = kernel_builder.allocate_kernel_arg_array(num_flattened_args)
122+
args_ty_list = kernel_builder.allocate_kernel_arg_ty_array(
123+
num_flattened_args
133124
)
134125
callargs_ptrs = []
135126
for arg in kernel_fn.kernel_args:
136127
callargs_ptrs.append(_getvar(lowerer, arg))
137128

138-
self.kernel_builder.populate_kernel_args_and_args_ty_arrays(
129+
kernel_builder.populate_kernel_args_and_args_ty_arrays(
139130
kernel_argtys=kernel_fn.kernel_arg_types,
140131
callargs_ptrs=callargs_ptrs,
141-
args_list=self.args_list,
142-
args_ty_list=self.args_ty_list,
132+
args_list=args_list,
133+
args_ty_list=args_ty_list,
143134
datamodel_mgr=dpex_dmm,
144135
)
145136

137+
return _KernelArgs(
138+
num_flattened_args=num_flattened_args,
139+
arg_vals=args_list,
140+
arg_types=args_ty_list,
141+
)
142+
146143
def _submit_parfor_kernel(
147144
self,
148145
lowerer,
@@ -156,9 +153,11 @@ def _submit_parfor_kernel(
156153
# Ensure that the Python arguments are kept alive for the duration of
157154
# the kernel execution
158155
keep_alive_kernels.append(kernel_fn.kernel)
156+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
157+
158+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
159+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
159160

160-
self._get_exec_queue(kernel_fn, lowerer)
161-
self._build_kernel_arglist(kernel_fn, lowerer)
162161
# Create a global range over which to submit the kernel based on the
163162
# loop_ranges of the parfor
164163
global_range = []
@@ -178,18 +177,26 @@ def _submit_parfor_kernel(
178177

179178
local_range = []
180179

180+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
181+
kernel_ref = lowerer.builder.inttoptr(
182+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
183+
cgutils.voidptr_t,
184+
)
185+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
186+
181187
# Submit a synchronous kernel
182-
self.kernel_builder.submit_sync_kernel(
183-
self.curr_queue,
184-
self.num_flattened_args,
185-
self.args_list,
186-
self.args_ty_list,
187-
global_range,
188-
local_range,
188+
kernel_builder.submit_sycl_kernel(
189+
sycl_kernel_ref=kernel_ref,
190+
sycl_queue_ref=curr_queue_ref,
191+
total_kernel_args=args.num_flattened_args,
192+
arg_list=args.arg_vals,
193+
arg_ty_list=args.arg_types,
194+
global_range=global_range,
195+
local_range=local_range,
189196
)
190197

191198
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
192-
self.kernel_builder.free_queue(sycl_queue_val=self.curr_queue)
199+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
193200

194201
def _submit_reduction_main_parfor_kernel(
195202
self,
@@ -204,9 +211,11 @@ def _submit_reduction_main_parfor_kernel(
204211
# Ensure that the Python arguments are kept alive for the duration of
205212
# the kernel execution
206213
keep_alive_kernels.append(kernel_fn.kernel)
214+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
215+
216+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
207217

208-
self._get_exec_queue(kernel_fn, lowerer)
209-
self._build_kernel_arglist(kernel_fn, lowerer)
218+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
210219
# Create a global range over which to submit the kernel based on the
211220
# loop_ranges of the parfor
212221
global_range = []
@@ -220,16 +229,27 @@ def _submit_reduction_main_parfor_kernel(
220229
_load_range(lowerer, reductionHelper.work_group_size)
221230
)
222231

232+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
233+
kernel_ref = lowerer.builder.inttoptr(
234+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
235+
cgutils.voidptr_t,
236+
)
237+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
238+
223239
# Submit a synchronous kernel
224-
self.kernel_builder.submit_sync_kernel(
225-
self.curr_queue,
226-
self.num_flattened_args,
227-
self.args_list,
228-
self.args_ty_list,
229-
global_range,
230-
local_range,
240+
kernel_builder.submit_sycl_kernel(
241+
sycl_kernel_ref=kernel_ref,
242+
sycl_queue_ref=curr_queue_ref,
243+
total_kernel_args=args.num_flattened_args,
244+
arg_list=args.arg_vals,
245+
arg_ty_list=args.arg_types,
246+
global_range=global_range,
247+
local_range=local_range,
231248
)
232249

250+
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
251+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
252+
233253
def _submit_reduction_remainder_parfor_kernel(
234254
self,
235255
lowerer,
@@ -243,8 +263,11 @@ def _submit_reduction_remainder_parfor_kernel(
243263
# the kernel execution
244264
keep_alive_kernels.append(kernel_fn.kernel)
245265

246-
self._get_exec_queue(kernel_fn, lowerer)
247-
self._build_kernel_arglist(kernel_fn, lowerer)
266+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
267+
268+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
269+
270+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
248271
# Create a global range over which to submit the kernel based on the
249272
# loop_ranges of the parfor
250273
global_range = []
@@ -255,16 +278,27 @@ def _submit_reduction_remainder_parfor_kernel(
255278

256279
local_range = []
257280

281+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
282+
kernel_ref = lowerer.builder.inttoptr(
283+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
284+
cgutils.voidptr_t,
285+
)
286+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
287+
258288
# Submit a synchronous kernel
259-
self.kernel_builder.submit_sync_kernel(
260-
self.curr_queue,
261-
self.num_flattened_args,
262-
self.args_list,
263-
self.args_ty_list,
264-
global_range,
265-
local_range,
289+
kernel_builder.submit_sycl_kernel(
290+
sycl_kernel_ref=kernel_ref,
291+
sycl_queue_ref=curr_queue_ref,
292+
total_kernel_args=args.num_flattened_args,
293+
arg_list=args.arg_vals,
294+
arg_ty_list=args.arg_types,
295+
global_range=global_range,
296+
local_range=local_range,
266297
)
267298

299+
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
300+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
301+
268302
def _reduction_codegen(
269303
self,
270304
parfor,

numba_dpex/core/parfors/reduction_helper.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,7 @@ def work_group_size(self):
395395

396396
def copy_final_sum_to_host(self, parfor_kernel):
397397
lowerer = self.lowerer
398-
ir_builder = KernelLaunchIRBuilder(
399-
lowerer.context,
400-
lowerer.builder,
401-
parfor_kernel.kernel.addressof_ref(),
402-
)
398+
ir_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
403399

404400
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
405401
# pointer.
@@ -447,4 +443,4 @@ def copy_final_sum_to_host(self, parfor_kernel):
447443
sycl.dpctl_event_wait(builder, event_ref)
448444
sycl.dpctl_event_delete(builder, event_ref)
449445

450-
ir_builder.free_queue(sycl_queue_val=curr_queue)
446+
ir_builder.free_queue(ptr_to_sycl_queue_ref=curr_queue)

0 commit comments

Comments
 (0)