Skip to content

Commit e52b423

Browse files
authored
Merge pull request #1424 from IntelPython/feature/migrate_parfor
Feature/migrate parfor
2 parents 5fb6093 + 40c465b commit e52b423

File tree

4 files changed

+79
-42
lines changed

4 files changed

+79
-42
lines changed

numba_dpex/core/parfors/kernel_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,17 @@ def __init__(
4444
kernel_args,
4545
kernel_arg_types,
4646
queue: dpctl.SyclQueue,
47+
local_accessors=None,
48+
work_group_size=None,
4749
):
4850
self.name = name
4951
self.kernel = kernel
5052
self.signature = signature
5153
self.kernel_args = kernel_args
5254
self.kernel_arg_types = kernel_arg_types
5355
self.queue = queue
56+
self.local_accessors = local_accessors
57+
self.work_group_size = work_group_size
5458

5559

5660
def _print_block(block):

numba_dpex/core/parfors/kernel_templates/reduction_template.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def __init__(
3030
parfor_args,
3131
parfor_reddict,
3232
redvars_dict,
33+
local_accessors_dict,
3334
typemap,
34-
work_group_size,
3535
) -> None:
3636
self._kernel_name = kernel_name
3737
self._kernel_params = kernel_params
@@ -44,8 +44,8 @@ def __init__(
4444
self._parfor_args = parfor_args
4545
self._parfor_reddict = parfor_reddict
4646
self._redvars_dict = redvars_dict
47+
self._local_accessors_dict = local_accessors_dict
4748
self._typemap = typemap
48-
self._work_group_size = work_group_size
4949

5050
self._kernel_txt = self._generate_kernel_stub_as_string()
5151
self._kernel_ir = self._generate_kernel_ir()
@@ -55,7 +55,7 @@ def _generate_kernel_stub_as_string(self):
5555

5656
gufunc_txt = ""
5757
gufunc_txt += "def " + self._kernel_name
58-
gufunc_txt += "(" + (", ".join(self._kernel_params)) + "):\n"
58+
gufunc_txt += "(nd_item, " + (", ".join(self._kernel_params)) + "):\n"
5959
global_id_dim = 0
6060
for_loop_dim = self._parfor_dim
6161

@@ -64,21 +64,17 @@ def _generate_kernel_stub_as_string(self):
6464
else:
6565
global_id_dim = self._parfor_dim
6666

67+
gufunc_txt += " group = nd_item.get_group()\n"
6768
for dim in range(global_id_dim):
6869
dstr = str(dim)
6970
gufunc_txt += (
70-
f" {self._ivar_names[dim]} = dpex.get_global_id({dstr})\n"
71+
f" {self._ivar_names[dim]} = nd_item.get_global_id({dstr})\n"
7172
)
72-
gufunc_txt += f" local_id{dim} = dpex.get_local_id({dstr})\n"
73-
gufunc_txt += f" local_size{dim} = dpex.get_local_size({dstr})\n"
74-
gufunc_txt += f" group_id{dim} = dpex.get_group_id({dstr})\n"
75-
76-
# Allocate local_sums arrays for each reduction variable.
77-
for redvar in self._redvars:
78-
rtyp = str(self._typemap[redvar])
79-
redvar = self._redvars_dict[redvar]
80-
gufunc_txt += f" local_sums_{redvar} = \
81-
dpex.local.array({self._work_group_size}, dpnp.{rtyp})\n"
73+
gufunc_txt += f" local_id{dim} = nd_item.get_local_id({dstr})\n"
74+
gufunc_txt += (
75+
f" local_size{dim} = group.get_local_range({dstr})\n"
76+
)
77+
gufunc_txt += f" group_id{dim} = group.get_group_id({dstr})\n"
8278

8379
for dim in range(global_id_dim, for_loop_dim):
8480
for indent in range(1 + (dim - global_id_dim)):
@@ -282,10 +278,13 @@ def _generate_kernel_stub_as_string(self):
282278
)
283279

284280
for redvar in self._redvars:
281+
rtyp = str(self._typemap[redvar])
285282
legal_redvar = self._redvars_dict[redvar]
286283
gufunc_txt += " "
287284
gufunc_txt += legal_redvar + " = "
288-
gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n"
285+
gufunc_txt += (
286+
f"dpnp.{rtyp}({self._parfor_reddict[redvar].init_val})\n"
287+
)
289288

290289
gufunc_txt += (
291290
" "
@@ -294,32 +293,17 @@ def _generate_kernel_stub_as_string(self):
294293
+ f"{self._global_size_var_name[0]} + j\n"
295294
)
296295

297-
for redvar in self._redvars:
298-
rtyp = str(self._typemap[redvar])
299-
redvar = self._redvars_dict[redvar]
300-
gufunc_txt += (
301-
" "
302-
+ f"local_sums_{redvar} = "
303-
+ f"dpex.local.array(1, dpnp.{rtyp})\n"
304-
)
305-
306296
gufunc_txt += " " + self._sentinel_name + " = 0\n"
307297

308-
for i, redvar in enumerate(self._redvars):
309-
legal_redvar = self._redvars_dict[redvar]
310-
gufunc_txt += (
311-
" " + f"local_sums_{legal_redvar}[0] = {legal_redvar}\n"
312-
)
313-
314298
for i, redvar in enumerate(self._redvars):
315299
legal_redvar = self._redvars_dict[redvar]
316300
redop = self._parfor_reddict[redvar].redop
317301
if redop == operator.iadd:
318302
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
319-
local_sums_{legal_redvar}[0]\n"
303+
{legal_redvar}\n"
320304
elif redop == operator.imul:
321305
gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \
322-
local_sums_{legal_redvar}[0]\n"
306+
{legal_redvar}\n"
323307
else:
324308
raise NotImplementedError
325309

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,24 @@ def _submit_parfor_kernel(
169169
queue_ref = kl_builder.get_queue(exec_queue=kernel_fn.queue)
170170

171171
kernel_args = []
172-
for arg in kernel_fn.kernel_args:
173-
kernel_args.append(_getvar(lowerer, arg))
172+
for i, arg in enumerate(kernel_fn.kernel_args):
173+
if (
174+
kernel_fn.local_accessors is not None
175+
and arg in kernel_fn.local_accessors
176+
):
177+
wg_size = lowerer.context.get_constant(
178+
types.intp, kernel_fn.work_group_size
179+
)
180+
la_shape = cgutils.pack_array(lowerer.builder, [wg_size])
181+
arg_ty = kernel_fn.kernel_arg_types[i]
182+
la = cgutils.create_struct_proxy(arg_ty)(
183+
lowerer.context,
184+
lowerer.builder,
185+
)
186+
la.shape = la_shape
187+
kernel_args.append(la._getvalue())
188+
else:
189+
kernel_args.append(_getvar(lowerer, arg))
174190

175191
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
176192
kernel_ref = lowerer.builder.inttoptr(

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
)
2020
from numba.core.typing import signature
2121

22+
from numba_dpex.core.parfors.reduction_helper import ReductionKernelVariables
2223
from numba_dpex.core.types import DpctlSyclQueue
24+
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
25+
from numba_dpex.core.types.kernel_api.local_accessor import LocalAccessorType
2326

2427
from .kernel_builder import _print_body # saved for debug
2528
from .kernel_builder import (
@@ -40,14 +43,15 @@ def create_reduction_main_kernel_for_parfor(
4043
typemap,
4144
flags,
4245
has_aliases,
43-
reductionKernelVar,
46+
reductionKernelVar: ReductionKernelVariables,
4447
parfor_reddict=None,
4548
):
4649
"""
4750
Creates a numba_dpex.kernel function for the reduction main kernel.
4851
"""
4952

5053
loc = parfor_node.init_block.loc
54+
parfor_dim = len(parfor_node.loop_nests)
5155

5256
for race in parfor_node.races:
5357
msg = (
@@ -77,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
7781
except KeyError:
7882
pass
7983

84+
parfor_params = reductionKernelVar.parfor_params.copy()
85+
parfor_legalized_params = reductionKernelVar.parfor_legalized_params.copy()
86+
parfor_param_types = reductionKernelVar.param_types.copy()
87+
local_accessors_dict = {}
88+
for k, v in reductionKernelVar.redvars_legal_dict.items():
89+
la_var = "local_sums_" + v
90+
local_accessors_dict[k] = la_var
91+
idx = reductionKernelVar.parfor_params.index(k)
92+
arr_ty = reductionKernelVar.param_types[idx]
93+
la_ty = LocalAccessorType(parfor_dim, arr_ty.dtype)
94+
95+
parfor_params.append(la_var)
96+
parfor_legalized_params.append(la_var)
97+
parfor_param_types.append(la_ty)
98+
8099
kernel_template = TreeReduceIntermediateKernelTemplate(
81100
kernel_name=kernel_name,
82-
kernel_params=reductionKernelVar.parfor_legalized_params,
101+
kernel_params=parfor_legalized_params,
83102
ivar_names=reductionKernelVar.legal_loop_indices,
84103
sentinel_name=sentinel_name,
85104
loop_ranges=loop_ranges,
86105
param_dict=reductionKernelVar.param_dict,
87-
parfor_dim=len(parfor_node.loop_nests),
106+
parfor_dim=parfor_dim,
88107
redvars=reductionKernelVar.parfor_redvars,
89-
parfor_args=reductionKernelVar.parfor_params,
108+
parfor_args=parfor_params,
90109
parfor_reddict=parfor_reddict,
91110
redvars_dict=reductionKernelVar.redvars_legal_dict,
111+
local_accessors_dict=local_accessors_dict,
92112
typemap=typemap,
93-
work_group_size=reductionKernelVar.work_group_size,
94113
)
95114
kernel_ir = kernel_template.kernel_ir
96115

@@ -114,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
114133
new_var_dict[name] = mk_unique_var(name)
115134

116135
replace_var_names(kernel_ir.blocks, new_var_dict)
117-
kernel_param_types = reductionKernelVar.param_types
136+
kernel_param_types = parfor_param_types
118137
kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1
119138
# Add kernel stub last label to each parfor.loop_body label to prevent
120139
# label conflicts.
@@ -136,6 +155,13 @@ def create_reduction_main_kernel_for_parfor(
136155
if not has_aliases:
137156
flags.noalias = True
138157

158+
# The first argument to an nd-range kernel is a kernel_api.NdItem object. The
159+
# ``NdItem`` object is used by the kernel_api.spirv backend to generate the
160+
# correct SPIR-V indexing instructions. Since the argument is not something
161+
# available originally in the kernel_param_types, we add it at this point to
162+
# make sure the kernel signature matches the actual generated code.
163+
ty_item = NdItemType(parfor_dim)
164+
kernel_param_types = (ty_item, *kernel_param_types)
139165
kernel_sig = signature(types.none, *kernel_param_types)
140166

141167
# FIXME: A better design is required so that we do not have to create a
@@ -155,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
155181

156182
flags.noalias = old_alias
157183

184+
parfor_params = (
185+
reductionKernelVar.parfor_params.copy()
186+
+ parfor_params[len(reductionKernelVar.parfor_params) :] # noqa: E203
187+
)
188+
158189
return ParforKernel(
159190
name=kernel_name,
160191
kernel=sycl_kernel,
161192
signature=kernel_sig,
162-
kernel_args=reductionKernelVar.parfor_params,
163-
kernel_arg_types=reductionKernelVar.func_arg_types,
193+
kernel_args=parfor_params,
194+
kernel_arg_types=parfor_param_types,
164195
queue=exec_queue,
196+
local_accessors=set(local_accessors_dict.values()),
197+
work_group_size=reductionKernelVar.work_group_size,
165198
)
166199

167200

0 commit comments

Comments
 (0)