Skip to content

Commit 9f8d1ac

Browse files
committed
Add nd_item to reduction template
1 parent 5fb6093 commit 9f8d1ac

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

numba_dpex/core/parfors/kernel_templates/reduction_template.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _generate_kernel_stub_as_string(self):
5555

5656
gufunc_txt = ""
5757
gufunc_txt += "def " + self._kernel_name
58-
gufunc_txt += "(" + (", ".join(self._kernel_params)) + "):\n"
58+
gufunc_txt += "(nd_item, " + (", ".join(self._kernel_params)) + "):\n"
5959
global_id_dim = 0
6060
for_loop_dim = self._parfor_dim
6161

@@ -64,14 +64,17 @@ def _generate_kernel_stub_as_string(self):
6464
else:
6565
global_id_dim = self._parfor_dim
6666

67+
gufunc_txt += " group = nd_item.get_group()\n"
6768
for dim in range(global_id_dim):
6869
dstr = str(dim)
6970
gufunc_txt += (
70-
f" {self._ivar_names[dim]} = dpex.get_global_id({dstr})\n"
71+
f" {self._ivar_names[dim]} = nd_item.get_global_id({dstr})\n"
7172
)
72-
gufunc_txt += f" local_id{dim} = dpex.get_local_id({dstr})\n"
73-
gufunc_txt += f" local_size{dim} = dpex.get_local_size({dstr})\n"
74-
gufunc_txt += f" group_id{dim} = dpex.get_group_id({dstr})\n"
73+
gufunc_txt += f" local_id{dim} = nd_item.get_local_id({dstr})\n"
74+
gufunc_txt += (
75+
f" local_size{dim} = group.get_local_range({dstr})\n"
76+
)
77+
gufunc_txt += f" group_id{dim} = group.get_group_id({dstr})\n"
7578

7679
# Allocate local_sums arrays for each reduction variable.
7780
for redvar in self._redvars:

numba_dpex/core/parfors/reduction_kernel_builder.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from numba.core.typing import signature
2121

2222
from numba_dpex.core.types import DpctlSyclQueue
23+
from numba_dpex.core.types.kernel_api.index_space_ids import NdItemType
2324

2425
from .kernel_builder import _print_body # saved for debug
2526
from .kernel_builder import (
@@ -48,6 +49,7 @@ def create_reduction_main_kernel_for_parfor(
4849
"""
4950

5051
loc = parfor_node.init_block.loc
52+
parfor_dim = len(parfor_node.loop_nests)
5153

5254
for race in parfor_node.races:
5355
msg = (
@@ -84,7 +86,7 @@ def create_reduction_main_kernel_for_parfor(
8486
sentinel_name=sentinel_name,
8587
loop_ranges=loop_ranges,
8688
param_dict=reductionKernelVar.param_dict,
87-
parfor_dim=len(parfor_node.loop_nests),
89+
parfor_dim=parfor_dim,
8890
redvars=reductionKernelVar.parfor_redvars,
8991
parfor_args=reductionKernelVar.parfor_params,
9092
parfor_reddict=parfor_reddict,
@@ -136,6 +138,13 @@ def create_reduction_main_kernel_for_parfor(
136138
if not has_aliases:
137139
flags.noalias = True
138140

141+
# The first argument to an nd-range kernel is a kernel_api.NdItem object. The
142+
# ``NdItem`` object is used by the kernel_api.spirv backend to generate the
143+
# correct SPIR-V indexing instructions. Since the argument is not something
144+
# available originally in the kernel_param_types, we add it at this point to
145+
# make sure the kernel signature matches the actual generated code.
146+
ty_item = NdItemType(parfor_dim)
147+
kernel_param_types = (ty_item, *kernel_param_types)
139148
kernel_sig = signature(types.none, *kernel_param_types)
140149

141150
# FIXME: A better design is required so that we do not have to create a

0 commit comments

Comments
 (0)