1919)
2020from numba .core .typing import signature
2121
22+ from numba_dpex .core .parfors .reduction_helper import ReductionKernelVariables
2223from numba_dpex .core .types import DpctlSyclQueue
2324from numba_dpex .core .types .kernel_api .index_space_ids import NdItemType
25+ from numba_dpex .core .types .kernel_api .local_accessor import LocalAccessorType
2426
2527from .kernel_builder import _print_body # saved for debug
2628from .kernel_builder import (
@@ -41,7 +43,7 @@ def create_reduction_main_kernel_for_parfor(
4143 typemap ,
4244 flags ,
4345 has_aliases ,
44- reductionKernelVar ,
46+ reductionKernelVar : ReductionKernelVariables ,
4547 parfor_reddict = None ,
4648):
4749 """
@@ -79,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
7981 except KeyError :
8082 pass
8183
84+ parfor_params = reductionKernelVar .parfor_params .copy ()
85+ parfor_legalized_params = reductionKernelVar .parfor_legalized_params .copy ()
86+ parfor_param_types = reductionKernelVar .param_types .copy ()
87+ local_accessors_dict = {}
88+ for k , v in reductionKernelVar .redvars_legal_dict .items ():
89+ la_var = "local_sums_" + v
90+ local_accessors_dict [k ] = la_var
91+ idx = reductionKernelVar .parfor_params .index (k )
92+ arr_ty = reductionKernelVar .param_types [idx ]
93+ la_ty = LocalAccessorType (parfor_dim , arr_ty .dtype )
94+
95+ parfor_params .append (la_var )
96+ parfor_legalized_params .append (la_var )
97+ parfor_param_types .append (la_ty )
98+
8299 kernel_template = TreeReduceIntermediateKernelTemplate (
83100 kernel_name = kernel_name ,
84- kernel_params = reductionKernelVar . parfor_legalized_params ,
101+ kernel_params = parfor_legalized_params ,
85102 ivar_names = reductionKernelVar .legal_loop_indices ,
86103 sentinel_name = sentinel_name ,
87104 loop_ranges = loop_ranges ,
88105 param_dict = reductionKernelVar .param_dict ,
89106 parfor_dim = parfor_dim ,
90107 redvars = reductionKernelVar .parfor_redvars ,
91- parfor_args = reductionKernelVar . parfor_params ,
108+ parfor_args = parfor_params ,
92109 parfor_reddict = parfor_reddict ,
93110 redvars_dict = reductionKernelVar .redvars_legal_dict ,
111+ local_accessors_dict = local_accessors_dict ,
94112 typemap = typemap ,
95- work_group_size = reductionKernelVar .work_group_size ,
96113 )
97114 kernel_ir = kernel_template .kernel_ir
98115
@@ -116,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
116133 new_var_dict [name ] = mk_unique_var (name )
117134
118135 replace_var_names (kernel_ir .blocks , new_var_dict )
119- kernel_param_types = reductionKernelVar . param_types
136+ kernel_param_types = parfor_param_types
120137 kernel_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
121138 # Add kernel stub last label to each parfor.loop_body label to prevent
122139 # label conflicts.
@@ -164,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
164181
165182 flags .noalias = old_alias
166183
184+ parfor_params = (
185+ reductionKernelVar .parfor_params .copy ()
186+ parfor_params [len (reductionKernelVar .parfor_params ) :] # noqa: E203
187+ )
188+
167189 return ParforKernel (
168190 name = kernel_name ,
169191 kernel = sycl_kernel ,
170192 signature = kernel_sig ,
171- kernel_args = reductionKernelVar . parfor_params ,
172- kernel_arg_types = reductionKernelVar . func_arg_types ,
193+ kernel_args = parfor_params ,
194+ kernel_arg_types = parfor_param_types ,
173195 queue = exec_queue ,
196+ local_accessors = set (local_accessors_dict .values ()),
197+ work_group_size = reductionKernelVar .work_group_size ,
174198 )
175199
176200
0 commit comments