Skip to content

Commit d5d48c0

Browse files
committed
Migrate parfor local accessor to new api
1 parent 9f8d1ac commit d5d48c0

File tree

4 files changed

+55
-18
lines changed

4 files changed

+55
-18
lines changed

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,17 @@ def __init__(
4444
kernel_args,
4545
kernel_arg_types,
4646
queue: dpctl.SyclQueue,
47+
local_accessors=None,
48+
work_group_size=None,
4749
):
4850
self.name = name
4951
self.kernel = kernel
5052
self.signature = signature
5153
self.kernel_args = kernel_args
5254
self.kernel_arg_types = kernel_arg_types
5355
self.queue = queue
56+
self.local_accessors = local_accessors
57+
self.work_group_size = work_group_size
5458

5559

5660
def _print_block(block):

numba_dpex/core/parfors/kernel_templates/reduction_template.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def __init__(
3030
parfor_args,
3131
parfor_reddict,
3232
redvars_dict,
33+
local_accessors_dict,
3334
typemap,
34-
work_group_size,
3535
) -> None:
3636
self._kernel_name = kernel_name
3737
self._kernel_params = kernel_params
@@ -44,8 +44,8 @@ def __init__(
4444
self._parfor_args = parfor_args
4545
self._parfor_reddict = parfor_reddict
4646
self._redvars_dict = redvars_dict
47+
self._local_accessors_dict = local_accessors_dict
4748
self._typemap = typemap
48-
self._work_group_size = work_group_size
4949

5050
self._kernel_txt = self._generate_kernel_stub_as_string()
5151
self._kernel_ir = self._generate_kernel_ir()
@@ -76,13 +76,6 @@ def _generate_kernel_stub_as_string(self):
7676
)
7777
gufunc_txt += f" group_id{dim} = group.get_group_id({dstr})\n"
7878

79-
# Allocate local_sums arrays for each reduction variable.
80-
for redvar in self._redvars:
81-
rtyp = str(self._typemap[redvar])
82-
redvar = self._redvars_dict[redvar]
83-
gufunc_txt += f" local_sums_{redvar} = \
84-
dpex.local.array({self._work_group_size}, dpnp.{rtyp})\n"
85-
8679
for dim in range(global_id_dim, for_loop_dim):
8780
for indent in range(1 + (dim - global_id_dim)):
8881
gufunc_txt += " "

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,24 @@ def _submit_parfor_kernel(
169169
queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue)
170170

171171
kernel_args = []
172-
for arg in kernel_fn.kernel_args:
173-
kernel_args.append(_getvar(lowerer, arg))
172+
for i, arg in enumerate(kernel_fn.kernel_args):
173+
if (
174+
kernel_fn.local_accessors is not None
175+
and arg in kernel_fn.local_accessors
176+
):
177+
wg_size = lowerer.context.get_constant(
178+
types.intp, kernel_fn.work_group_size
179+
)
180+
la_shape = cgutils.pack_array(lowerer.builder, [wg_size])
181+
arg_ty = kernel_fn.kernel_arg_types[i]
182+
la = cgutils.create_struct_proxy(arg_ty)(
183+
lowerer.context,
184+
lowerer.builder,
185+
)
186+
la.shape = la_shape
187+
kernel_args.append(la._getvalue())
188+
else:
189+
kernel_args.append(_getvar(lowerer, arg))
174190

175191
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
176192
kernel_ref = lowerer.builder.inttoptr(

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
)
2020
from numba.core.typing import signature
2121

22+
from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables
2223
from numba_dpex.core.types import DpctlSyclQueue
2324
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
25+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
2426

2527
from .kernel_builder import _print_body # saved for debug
2628
from .kernel_builder import (
@@ -41,7 +43,7 @@ def create_reduction_main_kernel_for_parfor(
4143
typemap,
4244
flags,
4345
has_aliases,
44-
reductionKernelVar,
46+
reductionKernelVar: ReductionKernelVariables,
4547
parfor_reddict=None,
4648
):
4749
"""
@@ -79,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
7981
except KeyError:
8082
pass
8183

84+
parfor_params = reductionKernelVar.parfor_params.copy()
85+
parfor_legalized_params = reductionKernelVar.parfor_legalized_params.copy()
86+
parfor_param_types = reductionKernelVar.param_types.copy()
87+
local_accessors_dict = {}
88+
for k, v in reductionKernelVar.redvars_legal_dict.items():
89+
la_var = "local_sums_" + v
90+
local_accessors_dict[k] = la_var
91+
idx = reductionKernelVar.parfor_params.index(k)
92+
arr_ty = reductionKernelVar.param_types[idx]
93+
la_ty = LocalAccessorType(parfor_dim, arr_ty.dtype)
94+
95+
parfor_params.append(la_var)
96+
parfor_legalized_params.append(la_var)
97+
parfor_param_types.append(la_ty)
98+
8299
kernel_template = TreeReduceIntermediateKernelTemplate(
83100
kernel_name=kernel_name,
84-
kernel_params=reductionKernelVar.parfor_legalized_params,
101+
kernel_params=parfor_legalized_params,
85102
ivar_names=reductionKernelVar.legal_loop_indices,
86103
sentinel_name=sentinel_name,
87104
loop_ranges=loop_ranges,
88105
param_dict=reductionKernelVar.param_dict,
89106
parfor_dim=parfor_dim,
90107
redvars=reductionKernelVar.parfor_redvars,
91-
parfor_args=reductionKernelVar.parfor_params,
108+
parfor_args=parfor_params,
92109
parfor_reddict=parfor_reddict,
93110
redvars_dict=reductionKernelVar.redvars_legal_dict,
111+
local_accessors_dict=local_accessors_dict,
94112
typemap=typemap,
95-
work_group_size=reductionKernelVar.work_group_size,
96113
)
97114
kernel_ir = kernel_template.kernel_ir
98115

@@ -116,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
116133
new_var_dict[name] = mk_unique_var(name)
117134

118135
replace_var_names(kernel_ir.blocks, new_var_dict)
119-
kernel_param_types = reductionKernelVar.param_types
136+
kernel_param_types = parfor_param_types
120137
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1
121138
# Add kernel stub last label to each parfor.loop_body label to prevent
122139
# label conflicts.
@@ -164,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
164181

165182
flags.noalias = old_alias
166183

184+
parfor_params = (
185+
reductionKernelVar.parfor_params.copy()
186+
+ parfor_params[len(reductionKernelVar.parfor_params) :] # noqa: $203
187+
)
188+
167189
return ParforKernel(
168190
name=kernel_name,
169191
kernel=sycl_kernel,
170192
signature=kernel_sig,
171-
kernel_args=reductionKernelVar.parfor_params,
172-
kernel_arg_types=reductionKernelVar.func_arg_types,
193+
kernel_args=parfor_params,
194+
kernel_arg_types=parfor_param_types,
173195
queue=exec_queue,
196+
local_accessors=set(local_accessors_dict.values()),
197+
work_group_size=reductionKernelVar.work_group_size,
174198
)
175199

176200

0 commit comments

Comments
 (0)