19
19
)
20
20
from numba .core .typing import signature
21
21
22
+ from numba_dpex .core .parfors .reduction_helper import ReductionKernelVariables
22
23
from numba_dpex .core .types import DpctlSyclQueue
23
24
from numba_dpex .core .types .kernel_api .index_space_ids import NdItemType
25
+ from numba_dpex .core .types .kernel_api .local_accessor import LocalAccessorType
24
26
25
27
from .kernel_builder import _print_body # saved for debug
26
28
from .kernel_builder import (
@@ -41,7 +43,7 @@ def create_reduction_main_kernel_for_parfor(
41
43
typemap ,
42
44
flags ,
43
45
has_aliases ,
44
- reductionKernelVar ,
46
+ reductionKernelVar : ReductionKernelVariables ,
45
47
parfor_reddict = None ,
46
48
):
47
49
"""
@@ -79,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
79
81
except KeyError :
80
82
pass
81
83
84
+ parfor_params = reductionKernelVar .parfor_params .copy ()
85
+ parfor_legalized_params = reductionKernelVar .parfor_legalized_params .copy ()
86
+ parfor_param_types = reductionKernelVar .param_types .copy ()
87
+ local_accessors_dict = {}
88
+ for k , v in reductionKernelVar .redvars_legal_dict .items ():
89
+ la_var = "local_sums_" + v
90
+ local_accessors_dict [k ] = la_var
91
+ idx = reductionKernelVar .parfor_params .index (k )
92
+ arr_ty = reductionKernelVar .param_types [idx ]
93
+ la_ty = LocalAccessorType (parfor_dim , arr_ty .dtype )
94
+
95
+ parfor_params .append (la_var )
96
+ parfor_legalized_params .append (la_var )
97
+ parfor_param_types .append (la_ty )
98
+
82
99
kernel_template = TreeReduceIntermediateKernelTemplate (
83
100
kernel_name = kernel_name ,
84
- kernel_params = reductionKernelVar . parfor_legalized_params ,
101
+ kernel_params = parfor_legalized_params ,
85
102
ivar_names = reductionKernelVar .legal_loop_indices ,
86
103
sentinel_name = sentinel_name ,
87
104
loop_ranges = loop_ranges ,
88
105
param_dict = reductionKernelVar .param_dict ,
89
106
parfor_dim = parfor_dim ,
90
107
redvars = reductionKernelVar .parfor_redvars ,
91
- parfor_args = reductionKernelVar . parfor_params ,
108
+ parfor_args = parfor_params ,
92
109
parfor_reddict = parfor_reddict ,
93
110
redvars_dict = reductionKernelVar .redvars_legal_dict ,
111
+ local_accessors_dict = local_accessors_dict ,
94
112
typemap = typemap ,
95
- work_group_size = reductionKernelVar .work_group_size ,
96
113
)
97
114
kernel_ir = kernel_template .kernel_ir
98
115
@@ -116,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
116
133
new_var_dict [name ] = mk_unique_var (name )
117
134
118
135
replace_var_names (kernel_ir .blocks , new_var_dict )
119
- kernel_param_types = reductionKernelVar . param_types
136
+ kernel_param_types = parfor_param_types
120
137
kernel_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
121
138
# Add kernel stub last label to each parfor.loop_body label to prevent
122
139
# label conflicts.
@@ -164,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
164
181
165
182
flags .noalias = old_alias
166
183
184
+ parfor_params = (
185
+ reductionKernelVar .parfor_params .copy ()
186
+ parfor_params [len (reductionKernelVar .parfor_params ) :] # noqa: E203
187
+ )
188
+
167
189
return ParforKernel (
168
190
name = kernel_name ,
169
191
kernel = sycl_kernel ,
170
192
signature = kernel_sig ,
171
- kernel_args = reductionKernelVar . parfor_params ,
172
- kernel_arg_types = reductionKernelVar . func_arg_types ,
193
+ kernel_args = parfor_params ,
194
+ kernel_arg_types = parfor_param_types ,
173
195
queue = exec_queue ,
196
+ local_accessors = set (local_accessors_dict .values ()),
197
+ work_group_size = reductionKernelVar .work_group_size ,
174
198
)
175
199
176
200
0 commit comments