@@ -77,16 +77,15 @@ struct scan<Config, BinOp, Exclusive, 1, device_capabilities>
     }
 };
 
-// 2-level scans
+// do level 0 scans for 2- and 3-level scans (same code)
 template<class Config, class BinOp, class device_capabilities>
-struct reduce<Config, BinOp, 2, device_capabilities>
+struct reduce_level0
 {
     using scalar_t = typename BinOp::type_t;
-    using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>;    // data accessor needs to be this type
-    using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
+    using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>;    // data accessor needs to be this type
 
-    template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
-    static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
+    template<class DataAccessor, class ScratchAccessor, class Params>
+    static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
     {
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 0 scan
@@ -104,7 +103,45 @@ struct reduce<Config, BinOp, 2, device_capabilities>
             }
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
+    }
+};
+
+template<class Config, class BinOp, class device_capabilities>
+struct scan_level0
+{
+    using scalar_t = typename BinOp::type_t;
+    using vector_t = vector<scalar_t, Config::ItemsPerInvocation_0>;    // data accessor needs to be this type
+
+    template<class DataAccessor, class ScratchAccessor, class Params>
+    static void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
+    {
+        const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
+        subgroup2::inclusive_scan<Params> inclusiveScan0;
+        // level 0 scan
+        [unroll]
+        for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
+        {
+            vector_t value;
+            dataAccessor.template get<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
+            value = inclusiveScan0(value);
+            dataAccessor.template set<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
+            if (Config::electLast())
+            {
+                const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
+                scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
+            }
+        }
+        scratchAccessor.workgroupExecutionAndMemoryBarrier();
     }
+};
+
+// 2-level scans
+template<class Config, class BinOp, class device_capabilities>
+struct reduce<Config, BinOp, 2, device_capabilities>
+{
+    using scalar_t = typename BinOp::type_t;
+    using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>;    // data accessor needs to be this type
+    using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
 
     template<class DataAccessor, class ScratchAccessor>
     scalar_t __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
@@ -114,7 +151,7 @@ struct reduce<Config, BinOp, 2, device_capabilities>
         using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
         BinOp binop;
 
-        __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
+        reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
@@ -145,28 +182,6 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
     using vector_lv0_t = vector<scalar_t, Config::ItemsPerInvocation_0>;    // data accessor needs to be this type
     using vector_lv1_t = vector<scalar_t, Config::ItemsPerInvocation_1>;
 
-    template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t>
-    static void __doLevel0(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
-    {
-        const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
-        subgroup2::inclusive_scan<Params> inclusiveScan0;
-        // level 0 scan
-        [unroll]
-        for (uint16_t idx = 0, virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
-        {
-            vector_t value;
-            dataAccessor.template get<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
-            value = inclusiveScan0(value);
-            dataAccessor.template set<vector_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
-            if (Config::electLast())
-            {
-                const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
-                scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1]);    // set last element of subgroup scan (reduction) to level 1 scan
-            }
-        }
-        scratchAccessor.workgroupExecutionAndMemoryBarrier();
-    }
-
     template<class DataAccessor, class ScratchAccessor>
     void __call(NBL_REF_ARG(DataAccessor) dataAccessor, NBL_REF_ARG(ScratchAccessor) scratchAccessor)
     {
@@ -175,7 +190,7 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
         using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
         BinOp binop;
 
-        __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
+        scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
@@ -243,7 +258,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
         using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
         BinOp binop;
 
-        reduce<Config, BinOp, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
+        reduce_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
@@ -300,7 +315,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
         BinOp binop;
 
-        scan<Config, BinOp, Exclusive, 2, device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
+        scan_level0<Config, BinOp, device_capabilities>::template __call<DataAccessor, ScratchAccessor, params_lv0_t>(dataAccessor, scratchAccessor);
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
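
The net effect of the diff: the level-0 subgroup pass that the 2-level reduce/scan specializations owned (and that the 3-level ones reached into via __doLevel0) now lives in the standalone reduce_level0 / scan_level0 helpers, and every level count forwards to one shared static __call. A minimal plain-C++ sketch of that delegation shape follows; it is illustrative only (the Accessor/levelsDone names and the trivial bodies are stand-ins, not the Nabla HLSL code):

#include <cstdio>

// Primary template, specialized per level count (mirrors reduce<Config, BinOp, N, ...> above).
template<int Levels> struct reduce;

// Shared level-0 step, written once (mirrors reduce_level0 / scan_level0).
struct reduce_level0
{
    template<class DataAccessor>
    static void __call(DataAccessor& data)
    {
        data.levelsDone = 1; // stand-in for the per-subgroup inclusive scan + scratch store
    }
};

template<> struct reduce<2>
{
    template<class DataAccessor>
    void __call(DataAccessor& data)
    {
        reduce_level0::template __call<DataAccessor>(data); // shared level 0
        data.levelsDone = 2;                                // stand-in for the level 1 pass
    }
};

template<> struct reduce<3>
{
    template<class DataAccessor>
    void __call(DataAccessor& data)
    {
        reduce_level0::template __call<DataAccessor>(data); // same shared level 0
        data.levelsDone = 3;                                // stand-in for the level 1 and 2 passes
    }
};

struct Accessor { int levelsDone = 0; };

int main()
{
    Accessor a;
    reduce<3>().__call(a);
    std::printf("levels done: %d\n", a.levelsDone); // prints: levels done: 3
}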