Skip to content

Commit 649ca87

Browse files
committed
Apply builder pattern to kernel launcher
1 parent 23e78a5 commit 649ca87

File tree

6 files changed

+202
-203
lines changed

6 files changed

+202
-203
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ repos:
5252
- id: pylint
5353
name: pylint
5454
entry: pylint
55-
files: ^numba_dpex/experimental
55+
files: ^numba_dpex/experimental|^numba_dpex/core/utils/kernel_launcher.py
5656
language: system
5757
types: [python]
5858
require_serial: true

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
signature,
4141
kernel_args,
4242
kernel_arg_types,
43-
queue,
43+
queue: dpctl.SyclQueue,
4444
):
4545
self.name = name
4646
self.kernel = kernel
@@ -244,7 +244,7 @@ def create_kernel_for_parfor(
244244
has_aliases,
245245
races,
246246
parfor_outputs,
247-
):
247+
) -> ParforKernel:
248248
"""
249249
Creates a numba_dpex.kernel function for a parfor node.
250250
@@ -422,7 +422,7 @@ def create_kernel_for_parfor(
422422
# arrays are on same device. We can take the queue from the first input
423423
# array and use that to compile the kernel.
424424

425-
exec_queue = None
425+
exec_queue: dpctl.SyclQueue = None
426426

427427
for arg in parfor_args:
428428
obj = typemap[arg]

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 29 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@
1313
)
1414

1515
from numba_dpex import config
16-
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
1716
from numba_dpex.core.parfors.reduction_helper import (
1817
ReductionHelper,
1918
ReductionKernelVariables,
2019
)
2120
from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder
21+
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
22+
from numba_dpex.core.datamodel.models import (
23+
dpex_data_model_manager as kernel_dmm,
24+
)
2225

2326
from ..exceptions import UnsupportedParforError
2427
from ..types.dpnp_ndarray_type import DpnpNdArray
@@ -28,11 +31,6 @@
2831
create_reduction_remainder_kernel_for_parfor,
2932
)
3033

31-
_KernelArgs = namedtuple(
32-
"_KernelArgs",
33-
["num_flattened_args", "arg_vals", "arg_types"],
34-
)
35-
3634

3735
# A global list of kernels to keep the objects alive indefinitely.
3836
keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
6866
var_val = lowerer.varmap[x]
6967

7068
if var_val:
71-
if not isinstance(var_val.type, llvmir.PointerType):
72-
with lowerer.builder.goto_entry_block():
73-
var_val_ptr = lowerer.builder.alloca(var_val.type)
74-
lowerer.builder.store(var_val, var_val_ptr)
75-
return var_val_ptr
69+
if isinstance(var_val.type, llvmir.PointerType):
70+
return lowerer.builder.load(var_val)
7671
else:
7772
return var_val
7873
else:
@@ -91,56 +86,6 @@ class ParforLowerImpl:
9186
for a parfor and submits it to a queue.
9287
"""
9388

94-
def _build_kernel_arglist(
95-
self, kernel_fn, lowerer, kernel_builder: KernelLaunchIRBuilder
96-
):
97-
"""Creates local variables for all the arguments and the argument types
98-
that are passes to the kernel function.
99-
100-
Args:
101-
kernel_fn: Kernel function to be launched.
102-
lowerer: The Numba lowerer used to generate the LLVM IR
103-
104-
Raises:
105-
AssertionError: If the LLVM IR Value for an argument defined in
106-
Numba IR is not found.
107-
"""
108-
num_flattened_args = 0
109-
110-
# Compute number of args to be passed to the kernel. Note that the
111-
# actual number of kernel arguments is greater than the count of
112-
# kernel_fn.kernel_args as arrays get flattened.
113-
for arg_type in kernel_fn.kernel_arg_types:
114-
if isinstance(arg_type, DpnpNdArray):
115-
datamodel = dpex_dmm.lookup(arg_type)
116-
num_flattened_args += datamodel.flattened_field_count
117-
elif arg_type == types.complex64 or arg_type == types.complex128:
118-
num_flattened_args += 2
119-
else:
120-
num_flattened_args += 1
121-
122-
# Create LLVM values for the kernel args list and kernel arg types list
123-
args_list = kernel_builder.allocate_kernel_arg_array(num_flattened_args)
124-
args_ty_list = kernel_builder.allocate_kernel_arg_ty_array(
125-
num_flattened_args
126-
)
127-
callargs_ptrs = []
128-
for arg in kernel_fn.kernel_args:
129-
callargs_ptrs.append(_getvar(lowerer, arg))
130-
131-
kernel_builder.populate_kernel_args_and_args_ty_arrays(
132-
kernel_argtys=kernel_fn.kernel_arg_types,
133-
callargs_ptrs=callargs_ptrs,
134-
args_list=args_list,
135-
args_ty_list=args_ty_list,
136-
)
137-
138-
return _KernelArgs(
139-
num_flattened_args=num_flattened_args,
140-
arg_vals=args_list,
141-
arg_types=args_ty_list,
142-
)
143-
14489
def _loop_ranges(
14590
self,
14691
lowerer,
@@ -163,7 +108,10 @@ def _loop_ranges(
163108
"non-unit strides are not yet supported."
164109
)
165110
global_range.append(stop)
166-
111+
# For now the local_range is always an empty list as numba_dpex always
112+
# submits kernels generated for parfor nodes as range kernels.
113+
# The provision is kept here if in future there is newer functionality
114+
# to submit these kernels as ndrange.
167115
local_range = []
168116

169117
return global_range, local_range
@@ -215,31 +163,34 @@ def _submit_parfor_kernel(
215163
# Ensure that the Python arguments are kept alive for the duration of
216164
# the kernel execution
217165
keep_alive_kernels.append(kernel_fn.kernel)
218-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
166+
kl_builder = KernelLaunchIRBuilder(
167+
lowerer.context, lowerer.builder, kernel_dmm
168+
)
169+
170+
queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue)
219171

220-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
221-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
172+
kernel_args = []
173+
for arg in kernel_fn.kernel_args:
174+
kernel_args.append(_getvar(lowerer, arg))
222175

223176
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
224177
kernel_ref = lowerer.builder.inttoptr(
225178
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
226179
cgutils.voidptr_t,
227180
)
228-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
229-
230-
# Submit a synchronous kernel
231-
kernel_builder.submit_sycl_kernel(
232-
sycl_kernel_ref=kernel_ref,
233-
sycl_queue_ref=curr_queue_ref,
234-
total_kernel_args=args.num_flattened_args,
235-
arg_list=args.arg_vals,
236-
arg_ty_list=args.arg_types,
237-
global_range=global_range,
238-
local_range=local_range,
181+
182+
kl_builder.set_kernel(kernel_ref)
183+
kl_builder.set_queue(queue_ref)
184+
kl_builder.set_range(global_range, local_range)
185+
kl_builder.set_arguments(
186+
kernel_fn.kernel_arg_types, kernel_args=kernel_args
239187
)
188+
kl_builder.set_dependant_event_list(dep_events=[])
189+
event_ref = kl_builder.submit()
240190

241-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
242-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
191+
sycl.dpctl_event_wait(lowerer.builder, event_ref)
192+
sycl.dpctl_event_delete(lowerer.builder, event_ref)
193+
sycl.dpctl_queue_delete(lowerer.builder, queue_ref)
243194

244195
def _reduction_codegen(
245196
self,

numba_dpex/core/parfors/reduction_helper.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from numba.parfors.parfor_lowering_utils import ParforLoweringBuilder
1818

1919
from numba_dpex import utils
20+
from numba_dpex.core.datamodel.models import (
21+
dpex_data_model_manager as kernel_dmm,
22+
)
2023
from numba_dpex.core.utils.kernel_launcher import KernelLaunchIRBuilder
2124
from numba_dpex.dpctl_iface import libsyclinterface_bindings as sycl
2225

@@ -395,11 +398,13 @@ def work_group_size(self):
395398

396399
def copy_final_sum_to_host(self, parfor_kernel):
397400
lowerer = self.lowerer
398-
ir_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
401+
kl_builder = KernelLaunchIRBuilder(
402+
lowerer.context, lowerer.builder, kernel_dmm
403+
)
399404

400405
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
401406
# pointer.
402-
curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue)
407+
queue_ref = kl_builder.get_queue(exec_queue=parfor_kernel.queue)
403408

404409
builder = lowerer.builder
405410
context = lowerer.context
@@ -433,7 +438,7 @@ def copy_final_sum_to_host(self, parfor_kernel):
433438
)
434439

435440
args = [
436-
builder.load(curr_queue),
441+
queue_ref,
437442
dest,
438443
src,
439444
builder.load(item_size),
@@ -443,4 +448,4 @@ def copy_final_sum_to_host(self, parfor_kernel):
443448
sycl.dpctl_event_wait(builder, event_ref)
444449
sycl.dpctl_event_delete(builder, event_ref)
445450

446-
ir_builder.free_queue(ptr_to_sycl_queue_ref=curr_queue)
451+
sycl.dpctl_queue_delete(builder, queue_ref)

0 commit comments

Comments
 (0)