19
19
)
20
20
from numba .core .typing import signature
21
21
22
+ from numba_dpex .core .parfors .reduction_helper import ReductionKernelVariables
22
23
from numba_dpex .core .types import DpctlSyclQueue
24
+ from numba_dpex .core .types .kernel_api .index_space_ids import NdItemType
25
+ from numba_dpex .core .types .kernel_api .local_accessor import LocalAccessorType
23
26
24
27
from .kernel_builder import _print_body # saved for debug
25
28
from .kernel_builder import (
@@ -40,14 +43,15 @@ def create_reduction_main_kernel_for_parfor(
40
43
typemap ,
41
44
flags ,
42
45
has_aliases ,
43
- reductionKernelVar ,
46
+ reductionKernelVar : ReductionKernelVariables ,
44
47
parfor_reddict = None ,
45
48
):
46
49
"""
47
50
Creates a numba_dpex.kernel function for reduction main kernel.
48
51
"""
49
52
50
53
loc = parfor_node .init_block .loc
54
+ parfor_dim = len (parfor_node .loop_nests )
51
55
52
56
for race in parfor_node .races :
53
57
msg = (
@@ -77,20 +81,35 @@ def create_reduction_main_kernel_for_parfor(
77
81
except KeyError :
78
82
pass
79
83
84
+ parfor_params = reductionKernelVar .parfor_params .copy ()
85
+ parfor_legalized_params = reductionKernelVar .parfor_legalized_params .copy ()
86
+ parfor_param_types = reductionKernelVar .param_types .copy ()
87
+ local_accessors_dict = {}
88
+ for k , v in reductionKernelVar .redvars_legal_dict .items ():
89
+ la_var = "local_sums_" + v
90
+ local_accessors_dict [k ] = la_var
91
+ idx = reductionKernelVar .parfor_params .index (k )
92
+ arr_ty = reductionKernelVar .param_types [idx ]
93
+ la_ty = LocalAccessorType (parfor_dim , arr_ty .dtype )
94
+
95
+ parfor_params .append (la_var )
96
+ parfor_legalized_params .append (la_var )
97
+ parfor_param_types .append (la_ty )
98
+
80
99
kernel_template = TreeReduceIntermediateKernelTemplate (
81
100
kernel_name = kernel_name ,
82
- kernel_params = reductionKernelVar . parfor_legalized_params ,
101
+ kernel_params = parfor_legalized_params ,
83
102
ivar_names = reductionKernelVar .legal_loop_indices ,
84
103
sentinel_name = sentinel_name ,
85
104
loop_ranges = loop_ranges ,
86
105
param_dict = reductionKernelVar .param_dict ,
87
- parfor_dim = len ( parfor_node . loop_nests ) ,
106
+ parfor_dim = parfor_dim ,
88
107
redvars = reductionKernelVar .parfor_redvars ,
89
- parfor_args = reductionKernelVar . parfor_params ,
108
+ parfor_args = parfor_params ,
90
109
parfor_reddict = parfor_reddict ,
91
110
redvars_dict = reductionKernelVar .redvars_legal_dict ,
111
+ local_accessors_dict = local_accessors_dict ,
92
112
typemap = typemap ,
93
- work_group_size = reductionKernelVar .work_group_size ,
94
113
)
95
114
kernel_ir = kernel_template .kernel_ir
96
115
@@ -114,7 +133,7 @@ def create_reduction_main_kernel_for_parfor(
114
133
new_var_dict [name ] = mk_unique_var (name )
115
134
116
135
replace_var_names (kernel_ir .blocks , new_var_dict )
117
- kernel_param_types = reductionKernelVar . param_types
136
+ kernel_param_types = parfor_param_types
118
137
kernel_stub_last_label = max (kernel_ir .blocks .keys ()) + 1
119
138
# Add kernel stub last label to each parfor.loop_body label to prevent
120
139
# label conflicts.
@@ -136,6 +155,13 @@ def create_reduction_main_kernel_for_parfor(
136
155
if not has_aliases :
137
156
flags .noalias = True
138
157
158
+ # The first argument to a range kernel is a kernel_api.NdItem object. The
159
+ # ``NdItem`` object is used by the kernel_api.spirv backend to generate the
160
+ # correct SPIR-V indexing instructions. Since, the argument is not something
161
+ # available originally in the kernel_param_types, we add it at this point to
162
+ # make sure the kernel signature matches the actual generated code.
163
+ ty_item = NdItemType (parfor_dim )
164
+ kernel_param_types = (ty_item , * kernel_param_types )
139
165
kernel_sig = signature (types .none , * kernel_param_types )
140
166
141
167
# FIXME: A better design is required so that we do not have to create a
@@ -155,13 +181,20 @@ def create_reduction_main_kernel_for_parfor(
155
181
156
182
flags .noalias = old_alias
157
183
184
+ parfor_params = (
185
+ reductionKernelVar .parfor_params .copy ()
186
+ + parfor_params [len (reductionKernelVar .parfor_params ) :] # noqa: $203
187
+ )
188
+
158
189
return ParforKernel (
159
190
name = kernel_name ,
160
191
kernel = sycl_kernel ,
161
192
signature = kernel_sig ,
162
- kernel_args = reductionKernelVar . parfor_params ,
163
- kernel_arg_types = reductionKernelVar . func_arg_types ,
193
+ kernel_args = parfor_params ,
194
+ kernel_arg_types = parfor_param_types ,
164
195
queue = exec_queue ,
196
+ local_accessors = set (local_accessors_dict .values ()),
197
+ work_group_size = reductionKernelVar .work_group_size ,
165
198
)
166
199
167
200
0 commit comments