Skip to content

Commit d254ee4

Browse files
author
Diptorup Deb
committed
Refactoring the kernel_launcher.KernelLaunchIRBuilder API.
- Changes the class constructor to make it easier to use from places other than the parfor_lowerer: instead of a lowerer object and a kernel, the constructor now takes a context, a builder, and the kernel address.
- Adds a new helper method, populate_kernel_args_and_args_ty_arrays, that populates the arrays storing the kernel args and the kernel arg types.
1 parent 3594cac commit d254ee4

File tree

3 files changed

+90
-68
lines changed

3 files changed

+90
-68
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 24 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313

1414
from numba_dpex import config
15+
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
1516
from numba_dpex.core.parfors.reduction_helper import (
1617
ReductionHelper,
1718
ReductionKernelVariables,
@@ -26,8 +27,6 @@
2627
create_reduction_remainder_kernel_for_parfor,
2728
)
2829

29-
from numba_dpex.core.datamodel.models import dpex_data_model_manager as dpex_dmm
30-
3130
# A global list of kernels to keep the objects alive indefinitely.
3231
keep_alive_kernels = []
3332

@@ -89,7 +88,9 @@ def _get_exec_queue(self, kernel_fn, lowerer):
8988
"""Creates a stack variable storing the sycl queue pointer used to
9089
launch the kernel function.
9190
"""
92-
self.kernel_builder = KernelLaunchIRBuilder(lowerer, kernel_fn.kernel)
91+
self.kernel_builder = KernelLaunchIRBuilder(
92+
lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref()
93+
)
9394

9495
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
9596
# pointer.
@@ -109,71 +110,38 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
109110
AssertionError: If the LLVM IR Value for an argument defined in
110111
Numba IR is not found.
111112
"""
112-
num_flattened_args = 0
113+
self.num_flattened_args = 0
113114

114115
# Compute number of args to be passed to the kernel. Note that the
115116
# actual number of kernel arguments is greater than the count of
116117
# kernel_fn.kernel_args as arrays get flattened.
117118
for arg_type in kernel_fn.kernel_arg_types:
118119
if isinstance(arg_type, DpnpNdArray):
119120
datamodel = dpex_dmm.lookup(arg_type)
120-
num_flattened_args += datamodel.flattened_field_count
121+
self.num_flattened_args += datamodel.flattened_field_count
121122
elif arg_type == types.complex64 or arg_type == types.complex128:
122-
num_flattened_args += 2
123+
self.num_flattened_args += 2
123124
else:
124-
num_flattened_args += 1
125+
self.num_flattened_args += 1
125126

126127
# Create LLVM values for the kernel args list and kernel arg types list
127128
self.args_list = self.kernel_builder.allocate_kernel_arg_array(
128-
num_flattened_args
129+
self.num_flattened_args
129130
)
130131
self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array(
131-
num_flattened_args
132+
self.num_flattened_args
133+
)
134+
callargs_ptrs = []
135+
for arg in kernel_fn.kernel_args:
136+
callargs_ptrs.append(_getvar(lowerer, arg))
137+
138+
self.kernel_builder.populate_kernel_args_and_args_ty_arrays(
139+
kernel_argtys=kernel_fn.kernel_arg_types,
140+
callargs_ptrs=callargs_ptrs,
141+
args_list=self.args_list,
142+
args_ty_list=self.args_ty_list,
143+
datamodel_mgr=dpex_dmm,
132144
)
133-
# Populate the args_list and the args_ty_list LLVM arrays
134-
self.kernel_arg_num = 0
135-
for arg_num, arg in enumerate(kernel_fn.kernel_args):
136-
argtype = kernel_fn.kernel_arg_types[arg_num]
137-
llvm_val = _getvar(lowerer, arg)
138-
if isinstance(argtype, DpnpNdArray):
139-
datamodel = dpex_dmm.lookup(argtype)
140-
self.kernel_builder.build_array_arg(
141-
array_val=llvm_val,
142-
array_data_model=datamodel,
143-
array_rank=argtype.ndim,
144-
arg_list=self.args_list,
145-
args_ty_list=self.args_ty_list,
146-
arg_num=self.kernel_arg_num,
147-
)
148-
self.kernel_arg_num += datamodel.flattened_field_count
149-
else:
150-
if argtype == types.complex64:
151-
self.kernel_builder.build_complex_arg(
152-
llvm_val,
153-
types.float32,
154-
self.args_list,
155-
self.args_ty_list,
156-
self.kernel_arg_num,
157-
)
158-
self.kernel_arg_num += 2
159-
elif argtype == types.complex128:
160-
self.kernel_builder.build_complex_arg(
161-
llvm_val,
162-
types.float64,
163-
self.args_list,
164-
self.args_ty_list,
165-
self.kernel_arg_num,
166-
)
167-
self.kernel_arg_num += 2
168-
else:
169-
self.kernel_builder.build_arg(
170-
llvm_val,
171-
argtype,
172-
self.args_list,
173-
self.args_ty_list,
174-
self.kernel_arg_num,
175-
)
176-
self.kernel_arg_num += 1
177145

178146
def _submit_parfor_kernel(
179147
self,
@@ -213,7 +181,7 @@ def _submit_parfor_kernel(
213181
# Submit a synchronous kernel
214182
self.kernel_builder.submit_sync_kernel(
215183
self.curr_queue,
216-
self.kernel_arg_num,
184+
self.num_flattened_args,
217185
self.args_list,
218186
self.args_ty_list,
219187
global_range,
@@ -255,7 +223,7 @@ def _submit_reduction_main_parfor_kernel(
255223
# Submit a synchronous kernel
256224
self.kernel_builder.submit_sync_kernel(
257225
self.curr_queue,
258-
self.kernel_arg_num,
226+
self.num_flattened_args,
259227
self.args_list,
260228
self.args_ty_list,
261229
global_range,
@@ -290,7 +258,7 @@ def _submit_reduction_remainder_parfor_kernel(
290258
# Submit a synchronous kernel
291259
self.kernel_builder.submit_sync_kernel(
292260
self.curr_queue,
293-
self.kernel_arg_num,
261+
self.num_flattened_args,
294262
self.args_list,
295263
self.args_ty_list,
296264
global_range,

numba_dpex/core/parfors/reduction_helper.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,17 @@ def lowerer(self):
393393
def work_group_size(self):
394394
return self._work_group_size
395395

396-
def copy_final_sum_to_host(self, psrfor_kernel):
396+
def copy_final_sum_to_host(self, parfor_kernel):
397397
lowerer = self.lowerer
398-
ir_builder = KernelLaunchIRBuilder(lowerer, psrfor_kernel.kernel)
398+
ir_builder = KernelLaunchIRBuilder(
399+
lowerer.context,
400+
lowerer.builder,
401+
parfor_kernel.kernel.addressof_ref(),
402+
)
399403

400404
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
401405
# pointer.
402-
curr_queue = ir_builder.get_queue(exec_queue=psrfor_kernel.queue)
406+
curr_queue = ir_builder.get_queue(exec_queue=parfor_kernel.queue)
403407

404408
builder = lowerer.builder
405409
context = lowerer.context

numba_dpex/core/utils/kernel_launcher.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from numba_dpex import utils
88
from numba_dpex.core.runtime.context import DpexRTContext
9+
from numba_dpex.core.types import DpnpNdArray
910
from numba_dpex.dpctl_iface import DpctlCAPIFnBuilder
1011
from numba_dpex.dpctl_iface._helpers import numba_type_to_dpctl_typenum
1112

@@ -19,19 +20,17 @@ class KernelLaunchIRBuilder:
1920
for submitting kernels. The LLVM Values that
2021
"""
2122

22-
def __init__(self, lowerer, kernel):
23+
def __init__(self, context, builder, kernel_addr):
2324
"""Create a KernelLauncher for the specified kernel.
2425
2526
Args:
26-
lowerer: The Numba Lowerer that will be used to generate the code.
27-
kernel: The SYCL kernel for which we are generating the code.
28-
num_inputs: The number of arguments to the kernels.
27+
context: A Numba target context that will be used to generate the code.
28+
builder: An llvmlite IRBuilder instance used to generate LLVM IR.
29+
kernel_addr: The address of a SYCL kernel.
2930
"""
30-
self.lowerer = lowerer
31-
self.context = self.lowerer.context
32-
self.builder = self.lowerer.builder
33-
self.kernel = kernel
34-
self.kernel_addr = self.kernel.addressof_ref()
31+
self.context = context
32+
self.builder = builder
33+
self.kernel_addr = kernel_addr
3534
self.rtctx = DpexRTContext(self.context)
3635

3736
def _build_nullptr(self):
@@ -402,3 +401,54 @@ def submit_sync_kernel(
402401
lr = self._create_sycl_range(local_range)
403402
args = args1 + [lr] + args2
404403
self.rtctx.submit_ndrange(self.builder, *args)
404+
405+
def populate_kernel_args_and_args_ty_arrays(
406+
self,
407+
kernel_argtys,
408+
callargs_ptrs,
409+
args_list,
410+
args_ty_list,
411+
datamodel_mgr,
412+
):
413+
kernel_arg_num = 0
414+
for arg_num, argtype in enumerate(kernel_argtys):
415+
llvm_val = callargs_ptrs[arg_num]
416+
if isinstance(argtype, DpnpNdArray):
417+
datamodel = datamodel_mgr.lookup(argtype)
418+
self.build_array_arg(
419+
array_val=llvm_val,
420+
array_data_model=datamodel,
421+
array_rank=argtype.ndim,
422+
arg_list=args_list,
423+
args_ty_list=args_ty_list,
424+
arg_num=kernel_arg_num,
425+
)
426+
kernel_arg_num += datamodel.flattened_field_count
427+
else:
428+
if argtype == types.complex64:
429+
self.build_complex_arg(
430+
llvm_val,
431+
types.float32,
432+
args_list,
433+
args_ty_list,
434+
kernel_arg_num,
435+
)
436+
kernel_arg_num += 2
437+
elif argtype == types.complex128:
438+
self.build_complex_arg(
439+
llvm_val,
440+
types.float64,
441+
args_list,
442+
args_ty_list,
443+
kernel_arg_num,
444+
)
445+
kernel_arg_num += 2
446+
else:
447+
self.build_arg(
448+
llvm_val,
449+
argtype,
450+
args_list,
451+
args_ty_list,
452+
kernel_arg_num,
453+
)
454+
kernel_arg_num += 1

0 commit comments

Comments
 (0)