Skip to content

Commit 2d6ddb8

Browse files
authored
Merge pull request #1243 from IntelPython/feature/apply_builder_pattern_to_kernel_launcher
Apply builder pattern to kernel launcher
2 parents f4b618f + 649ca87 commit 2d6ddb8

File tree

6 files changed

+241
-287
lines changed

6 files changed

+241
-287
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ repos:
5252
- id: pylint
5353
name: pylint
5454
entry: pylint
55-
files: ^numba_dpex/experimental
55+
files: ^numba_dpex/experimental|^numba_dpex/core/utils/kernel_launcher.py
5656
language: system
5757
types: [python]
5858
require_serial: true

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
signature,
4141
kernel_args,
4242
kernel_arg_types,
43-
queue,
43+
queue: dpctl.SyclQueue,
4444
):
4545
self.name = name
4646
self.kernel = kernel
@@ -244,7 +244,7 @@ def create_kernel_for_parfor(
244244
has_aliases,
245245
races,
246246
parfor_outputs,
247-
):
247+
) -> ParforKernel:
248248
"""
249249
Creates a numba_dpex.kernel function for a parfor node.
250250
@@ -422,7 +422,7 @@ def create_kernel_for_parfor(
422422
# arrays are on same device. We can take the queue from the first input
423423
# array and use that to compile the kernel.
424424

425-
exec_queue = None
425+
exec_queue: dpctl.SyclQueue = None
426426

427427
for arg in parfor_args:
428428
obj = typemap[arg]

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 68 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,24 @@
1313
)
1414

1515
from numba_dpex import config
16-
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
1716
from numba_dpex.core.parfors.reduction_helper import (
1817
ReductionHelper,
1918
ReductionKernelVariables,
2019
)
2120
from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder
21+
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
22+
from numba_dpex.core.datamodel.models import (
23+
dpex_data_model_manager as kernel_dmm,
24+
)
2225

2326
from ..exceptions import UnsupportedParforError
2427
from ..types.dpnp_ndarray_type import DpnpNdArray
25-
from .kernel_builder import create_kernel_for_parfor
28+
from .kernel_builder import ParforKernel, create_kernel_for_parfor
2629
from .reduction_kernel_builder import (
2730
create_reduction_main_kernel_for_parfor,
2831
create_reduction_remainder_kernel_for_parfor,
2932
)
3033

31-
_KernelArgs = namedtuple(
32-
"_KernelArgs",
33-
["num_flattened_args", "arg_vals", "arg_types"],
34-
)
35-
3634

3735
# A global list of kernels to keep the objects alive indefinitely.
3836
keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
6866
var_val = lowerer.varmap[x]
6967

7068
if var_val:
71-
if not isinstance(var_val.type, llvmir.PointerType):
72-
with lowerer.builder.goto_entry_block():
73-
var_val_ptr = lowerer.builder.alloca(var_val.type)
74-
lowerer.builder.store(var_val, var_val_ptr)
75-
return var_val_ptr
69+
if isinstance(var_val.type, llvmir.PointerType):
70+
return lowerer.builder.load(var_val)
7671
else:
7772
return var_val
7873
else:
@@ -91,75 +86,15 @@ class ParforLowerImpl:
9186
for a parfor and submits it to a queue.
9287
"""
9388

94-
def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
95-
"""Creates local variables for all the arguments and the argument types
96-
that are passes to the kernel function.
97-
98-
Args:
99-
kernel_fn: Kernel function to be launched.
100-
lowerer: The Numba lowerer used to generate the LLVM IR
101-
102-
Raises:
103-
AssertionError: If the LLVM IR Value for an argument defined in
104-
Numba IR is not found.
105-
"""
106-
num_flattened_args = 0
107-
108-
# Compute number of args to be passed to the kernel. Note that the
109-
# actual number of kernel arguments is greater than the count of
110-
# kernel_fn.kernel_args as arrays get flattened.
111-
for arg_type in kernel_fn.kernel_arg_types:
112-
if isinstance(arg_type, DpnpNdArray):
113-
datamodel = dpex_dmm.lookup(arg_type)
114-
num_flattened_args += datamodel.flattened_field_count
115-
elif arg_type == types.complex64 or arg_type == types.complex128:
116-
num_flattened_args += 2
117-
else:
118-
num_flattened_args += 1
119-
120-
# Create LLVM values for the kernel args list and kernel arg types list
121-
args_list = kernel_builder.allocate_kernel_arg_array(num_flattened_args)
122-
args_ty_list = kernel_builder.allocate_kernel_arg_ty_array(
123-
num_flattened_args
124-
)
125-
callargs_ptrs = []
126-
for arg in kernel_fn.kernel_args:
127-
callargs_ptrs.append(_getvar(lowerer, arg))
128-
129-
kernel_builder.populate_kernel_args_and_args_ty_arrays(
130-
kernel_argtys=kernel_fn.kernel_arg_types,
131-
callargs_ptrs=callargs_ptrs,
132-
args_list=args_list,
133-
args_ty_list=args_ty_list,
134-
)
135-
136-
return _KernelArgs(
137-
num_flattened_args=num_flattened_args,
138-
arg_vals=args_list,
139-
arg_types=args_ty_list,
140-
)
141-
142-
def _submit_parfor_kernel(
89+
def _loop_ranges(
14390
self,
14491
lowerer,
145-
kernel_fn,
14692
loop_ranges,
14793
):
148-
"""
149-
Adds a call to submit a kernel function into the function body of the
150-
current Numba JIT compiled function.
151-
"""
152-
# Ensure that the Python arguments are kept alive for the duration of
153-
# the kernel execution
154-
keep_alive_kernels.append(kernel_fn.kernel)
155-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
156-
157-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
158-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
159-
16094
# Create a global range over which to submit the kernel based on the
16195
# loop_ranges of the parfor
16296
global_range = []
97+
16398
# SYCL ranges can have at max 3 dimension. If the parfor is of a higher
16499
# dimension then the indexing for the higher dimensions is done inside
165100
# the kernel.
@@ -173,48 +108,19 @@ def _submit_parfor_kernel(
173108
"non-unit strides are not yet supported."
174109
)
175110
global_range.append(stop)
176-
111+
# For now the local_range is always an empty list as numba_dpex always
112+
# submits kernels generated for parfor nodes as range kernels.
113+
# The provision is kept here if in future there is newer functionality
114+
# to submit these kernels as ndrange.
177115
local_range = []
178116

179-
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
180-
kernel_ref = lowerer.builder.inttoptr(
181-
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
182-
cgutils.voidptr_t,
183-
)
184-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
185-
186-
# Submit a synchronous kernel
187-
kernel_builder.submit_sycl_kernel(
188-
sycl_kernel_ref=kernel_ref,
189-
sycl_queue_ref=curr_queue_ref,
190-
total_kernel_args=args.num_flattened_args,
191-
arg_list=args.arg_vals,
192-
arg_ty_list=args.arg_types,
193-
global_range=global_range,
194-
local_range=local_range,
195-
)
196-
197-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
198-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
117+
return global_range, local_range
199118

200-
def _submit_reduction_main_parfor_kernel(
119+
def _reduction_ranges(
201120
self,
202121
lowerer,
203-
kernel_fn,
204122
reductionHelper=None,
205123
):
206-
"""
207-
Adds a call to submit the main kernel of a parfor reduction into the
208-
function body of the current Numba JIT compiled function.
209-
"""
210-
# Ensure that the Python arguments are kept alive for the duration of
211-
# the kernel execution
212-
keep_alive_kernels.append(kernel_fn.kernel)
213-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
214-
215-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
216-
217-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
218124
# Create a global range over which to submit the kernel based on the
219125
# loop_ranges of the parfor
220126
global_range = []
@@ -228,75 +134,63 @@ def _submit_reduction_main_parfor_kernel(
228134
_load_range(lowerer, reductionHelper.work_group_size)
229135
)
230136

231-
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
232-
kernel_ref = lowerer.builder.inttoptr(
233-
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
234-
cgutils.voidptr_t,
235-
)
236-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
237-
238-
# Submit a synchronous kernel
239-
kernel_builder.submit_sycl_kernel(
240-
sycl_kernel_ref=kernel_ref,
241-
sycl_queue_ref=curr_queue_ref,
242-
total_kernel_args=args.num_flattened_args,
243-
arg_list=args.arg_vals,
244-
arg_ty_list=args.arg_types,
245-
global_range=global_range,
246-
local_range=local_range,
247-
)
137+
return global_range, local_range
248138

249-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
250-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
139+
def _remainder_ranges(self, lowerer):
140+
# Create a global range over which to submit the kernel based on the
141+
# loop_ranges of the parfor
142+
global_range = []
251143

252-
def _submit_reduction_remainder_parfor_kernel(
144+
stop = _load_range(lowerer, 1)
145+
146+
global_range.append(stop)
147+
148+
local_range = []
149+
150+
return global_range, local_range
151+
152+
def _submit_parfor_kernel(
253153
self,
254154
lowerer,
255-
kernel_fn,
155+
kernel_fn: ParforKernel,
156+
global_range,
157+
local_range,
256158
):
257159
"""
258-
Adds a call to submit the remainder kernel of a parfor reduction into
259-
the function body of the current Numba JIT compiled function.
160+
Adds a call to submit a kernel function into the function body of the
161+
current Numba JIT compiled function.
260162
"""
261163
# Ensure that the Python arguments are kept alive for the duration of
262164
# the kernel execution
263165
keep_alive_kernels.append(kernel_fn.kernel)
166+
kl_builder = KernelLaunchIRBuilder(
167+
lowerer.context, lowerer.builder, kernel_dmm
168+
)
264169

265-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
266-
267-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
268-
269-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
270-
# Create a global range over which to submit the kernel based on the
271-
# loop_ranges of the parfor
272-
global_range = []
273-
274-
stop = _load_range(lowerer, 1)
170+
queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue)
275171

276-
global_range.append(stop)
277-
278-
local_range = []
172+
kernel_args = []
173+
for arg in kernel_fn.kernel_args:
174+
kernel_args.append(_getvar(lowerer, arg))
279175

280176
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
281177
kernel_ref = lowerer.builder.inttoptr(
282178
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
283179
cgutils.voidptr_t,
284180
)
285-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
286-
287-
# Submit a synchronous kernel
288-
kernel_builder.submit_sycl_kernel(
289-
sycl_kernel_ref=kernel_ref,
290-
sycl_queue_ref=curr_queue_ref,
291-
total_kernel_args=args.num_flattened_args,
292-
arg_list=args.arg_vals,
293-
arg_ty_list=args.arg_types,
294-
global_range=global_range,
295-
local_range=local_range,
181+
182+
kl_builder.set_kernel(kernel_ref)
183+
kl_builder.set_queue(queue_ref)
184+
kl_builder.set_range(global_range, local_range)
185+
kl_builder.set_arguments(
186+
kernel_fn.kernel_arg_types, kernel_args=kernel_args
296187
)
188+
kl_builder.set_dependant_event_list(dep_events=[])
189+
event_ref = kl_builder.submit()
297190

298-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
299-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
191+
sycl.dpctl_event_wait(lowerer.builder, event_ref)
192+
sycl.dpctl_event_delete(lowerer.builder, event_ref)
193+
sycl.dpctl_queue_delete(lowerer.builder, queue_ref)
300194

301195
def _reduction_codegen(
302196
self,
@@ -360,10 +254,15 @@ def _reduction_codegen(
360254
parfor_reddict,
361255
)
362256

363-
self._submit_reduction_main_parfor_kernel(
257+
global_range, local_range = self._reduction_ranges(
258+
lowerer, reductionHelperList[0]
259+
)
260+
261+
self._submit_parfor_kernel(
364262
lowerer,
365263
parfor_kernel,
366-
reductionHelperList[0],
264+
global_range,
265+
local_range,
367266
)
368267

369268
parfor_kernel = create_reduction_remainder_kernel_for_parfor(
@@ -376,9 +275,13 @@ def _reduction_codegen(
376275
reductionHelperList,
377276
)
378277

379-
self._submit_reduction_remainder_parfor_kernel(
278+
global_range, local_range = self._remainder_ranges(lowerer)
279+
280+
self._submit_parfor_kernel(
380281
lowerer,
381282
parfor_kernel,
283+
global_range,
284+
local_range,
382285
)
383286

384287
reductionKernelVar.copy_final_sum_to_host(parfor_kernel)
@@ -492,11 +395,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492395
# FIXME: Make the exception more informative
493396
raise UnsupportedParforError
494397

398+
global_range, local_range = self._loop_ranges(lowerer, loop_ranges)
399+
495400
# Finally submit the kernel
496401
self._submit_parfor_kernel(
497402
lowerer,
498403
parfor_kernel,
499-
loop_ranges,
404+
global_range,
405+
local_range,
500406
)
501407

502408
# TODO: free the kernel at this point

0 commit comments

Comments
 (0)