1919)
2020from numba .core .typing import signature
2121
22+ from numba_dpex .core .parfors .reduction_helper import ReductionKernelVariables
2223from numba_dpex .core .types import DpctlSyclQueue
24+ from numba_dpex .core .types .kernel_api .index_space_ids import NdItemType
25+ from numba_dpex .core .types .kernel_api .local_accessor import LocalAccessorType
2326
2427from .kernel_builder import _print_body # saved for debug
2528from .kernel_builder import (
@@ -40,14 +43,15 @@ def create_reduction_main_kernel_for_parfor(
4043 typemap ,
4144 flags ,
4245 has_aliases ,
43- reductionKernelVar ,
46+ reductionKernelVar : ReductionKernelVariables ,
4447 parfor_reddict = None ,
4548):
4649 """
4750 Creates a numba_dpex.kernel function for reduction main kernel.
4851 """
4952
5053 loc = parfor_node .init_block .loc
54+ parfor_dim = len (parfor_node .loop_nests )
5155
5256 for race in parfor_node .races :
5357 msg = (
@@ -77,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
7781 except KeyError :
7882 pass
7983
84+ parfor_params = reductionKernelVar .parfor_params .copy ()
85+ parfor_legalized_params = reductionKernelVar .parfor_legalized_params .copy ()
86+ parfor_param_types = reductionKernelVar .param_types .copy ()
87+ local_accessors_dict = {}
88+ for k , v in reductionKernelVar .redvars_legal_dict .items ():
89+ la_var = "local_sums_" + v
90+ local_accessors_dict [k ] = la_var
91+ idx = reductionKernelVar .parfor_params .index (k )
92+ arr_ty = reductionKernelVar .param_types [idx ]
93+ la_ty = LocalAccessorType (parfor_dim , arr_ty .dtype )
94+
95+ parfor_params .append (la_var )
96+ parfor_legalized_params .append (la_var )
97+ parfor_param_types .append (la_ty )
98+
8099 kernel_template = TreeReduceIntermediateKernelTemplate (
81100 kernel_name = kernel_name ,
82- kernel_params = reductionKernelVar . parfor_legalized_params ,
101+ kernel_params = parfor_legalized_params ,
83102 ivar_names = reductionKernelVar .legal_loop_indices ,
84103 sentinel_name = sentinel_name ,
85104 loop_ranges = loop_ranges ,
86105 param_dict = reductionKernelVar .param_dict ,
87- parfor_dim = len ( parfor_node . loop_nests ) ,
106+ parfor_dim = parfor_dim ,
88107 redvars = reductionKernelVar .parfor_redvars ,
89- parfor_args = reductionKernelVar . parfor_params ,
108+ parfor_args = parfor_params ,
90109 parfor_reddict = parfor_reddict ,
91110 redvars_dict = reductionKernelVar .redvars_legal_dict ,
111+ local_accessors_dict = local_accessors_dict ,
92112 typemap = typemap ,
93- work_group_size = reductionKernelVar .work_group_size ,
94113 )
95114 kernel_ir = kernel_template .kernel_ir
96115
@@ -114,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
114133 new_var_dict [name ] = mk_unique_var (name )
115134
116135 replace_var_names (kernel_ir .blocks , new_var_dict )
117- kernel_param_types = reductionKernelVar . param_types
136+ kernel_param_types = parfor_param_types
118137 kernel_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
119138 # Add kernel stub last label to each parfor.loop_body label to prevent
120139 # label conflicts.
@@ -136,6 +155,13 @@ def create_reduction_main_kernel_for_parfor(
136155 if not has_aliases :
137156 flags .noalias = True
138157
158+ # The first argument to a range kernel is a kernel_api.NdItem object. The
159+ # ``NdItem`` object is used by the kernel_api.spirv backend to generate the
160+ # correct SPIR-V indexing instructions. Since the argument is not something
161+ # available originally in the kernel_param_types, we add it at this point to
162+ # make sure the kernel signature matches the actual generated code.
163+ ty_item = NdItemType (parfor_dim )
164+ kernel_param_types = (ty_item , * kernel_param_types )
139165 kernel_sig = signature (types .none , * kernel_param_types )
140166
141167 # FIXME: A better design is required so that we do not have to create a
@@ -155,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
155181
156182 flags .noalias = old_alias
157183
184+ parfor_params = (
185+ reductionKernelVar .parfor_params .copy ()
186+ + parfor_params [len (reductionKernelVar .parfor_params ) :] # noqa: E203
187+ )
188+
158189 return ParforKernel (
159190 name = kernel_name ,
160191 kernel = sycl_kernel ,
161192 signature = kernel_sig ,
162- kernel_args = reductionKernelVar . parfor_params ,
163- kernel_arg_types = reductionKernelVar . func_arg_types ,
193+ kernel_args = parfor_params ,
194+ kernel_arg_types = parfor_param_types ,
164195 queue = exec_queue ,
196+ local_accessors = set (local_accessors_dict .values ()),
197+ work_group_size = reductionKernelVar .work_group_size ,
165198 )
166199
167200
0 commit comments