@@ -85,22 +85,17 @@ struct reduce<Config, BinOp, 2, device_capabilities>
85
85
using vector_lv0_t = vector <scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
86
86
using vector_lv1_t = vector <scalar_t, Config::ItemsPerInvocation_1>;
87
87
88
- template<class DataAccessor, class ScratchAccessor>
89
- scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
88
+ template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t >
89
+ static void __doLevel0 (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
90
90
{
91
- using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
92
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
93
- using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
94
- BinOp binop;
95
-
96
91
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
97
92
// level 0 scan
98
- subgroup2::reduction<params_lv0_t > reduction0;
93
+ subgroup2::reduction<Params > reduction0;
99
94
[unroll]
100
95
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
101
96
{
102
- vector_lv0_t scan_local;
103
- dataAccessor.template get<vector_lv0_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
97
+ vector_t scan_local;
98
+ dataAccessor.template get<vector_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
104
99
scan_local = reduction0 (scan_local);
105
100
if (Config::electLast ())
106
101
{
@@ -109,7 +104,19 @@ struct reduce<Config, BinOp, 2, device_capabilities>
109
104
}
110
105
}
111
106
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
107
+ }
108
+
109
+ template<class DataAccessor, class ScratchAccessor>
110
+ scalar_t __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
111
+ {
112
+ using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
113
+ using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
114
+ using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
115
+ BinOp binop;
112
116
117
+ __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
118
+
119
+ const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
113
120
// level 1 scan
114
121
subgroup2::reduction<params_lv1_t> reduction1;
115
122
if (glsl::gl_SubgroupID () == 0 )
@@ -138,32 +145,39 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
138
145
using vector_lv0_t = vector <scalar_t, Config::ItemsPerInvocation_0>; // data accessor needs to be this type
139
146
using vector_lv1_t = vector <scalar_t, Config::ItemsPerInvocation_1>;
140
147
141
- template<class DataAccessor, class ScratchAccessor>
142
- void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
148
+ template<class DataAccessor, class ScratchAccessor, class Params, typename vector_t >
149
+ static void __doLevel0 (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
143
150
{
// Level-0 pass factored out of __call so it can be shared: the 2-level and the 3-level scan
// __call bodies in this diff both invoke it with params_lv0_t / vector_lv0_t.
// For every loop iteration it reads a vector_t through dataAccessor, runs
// subgroup2::inclusive_scan<Params> on it and writes the result back to the same index.
// The invocation for which Config::electLast() holds also stores one component of the
// scanned vector into scratchAccessor. The function finishes with
// scratchAccessor.workgroupExecutionAndMemoryBarrier().
// Params: the arithmetic parameter type handed to subgroup2::inclusive_scan.
// vector_t: the per-invocation vector type used for the dataAccessor get/set calls.
144
- using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
145
- using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
146
- using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
147
- BinOp binop;
148
-
149
151
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
150
- subgroup2::inclusive_scan<params_lv0_t > inclusiveScan0;
152
+ subgroup2::inclusive_scan<Params > inclusiveScan0;
151
153
// level 0 scan
152
154
[unroll]
153
155
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
154
156
{
// NOTE(review): virtualInvocationIndex is initialised from invocationIndex once and is not
// advanced by the loop header; only idx changes between iterations.
155
- vector_lv0_t value;
156
- dataAccessor.template get<vector_lv0_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
157
+ vector_t value;
158
+ dataAccessor.template get<vector_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
157
159
value = inclusiveScan0 (value);
158
- dataAccessor.template set<vector_lv0_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
160
+ dataAccessor.template set<vector_t , uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
159
161
if (Config::electLast ())
160
162
{
// NOTE(review): the component index below is spelled with Config::ItemsPerInvocation_0 rather
// than being derived from vector_t, so this presumably assumes vector_t has exactly
// ItemsPerInvocation_0 components - confirm before calling with any other vector type.
161
163
const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1 >(glsl::gl_SubgroupID (), idx);
162
164
scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
163
165
}
164
166
}
165
167
scratchAccessor.workgroupExecutionAndMemoryBarrier (); // issued once, after the whole level-0 loop has run
168
+ }
166
169
170
+ template<class DataAccessor, class ScratchAccessor>
171
+ void __call (NBL_REF_ARG (DataAccessor) dataAccessor, NBL_REF_ARG (ScratchAccessor) scratchAccessor)
172
+ {
173
+ using config_t = subgroup2::Configuration<Config::SubgroupSizeLog2>;
174
+ using params_lv0_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_0, device_capabilities>;
175
+ using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
176
+ BinOp binop;
177
+
178
+ __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
179
+
180
+ const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
167
181
// level 1 scan
168
182
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
169
183
if (glsl::gl_SubgroupID () == 0 )
@@ -228,23 +242,9 @@ struct reduce<Config, BinOp, 3, device_capabilities>
228
242
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
229
243
BinOp binop;
230
244
231
- const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
232
- // level 0 scan
233
- subgroup2::reduction<params_lv0_t> reduction0;
234
- [unroll]
235
- for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
236
- {
237
- vector_lv0_t scan_local;
238
- dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
239
- scan_local = reduction0 (scan_local);
240
- if (Config::electLast ())
241
- {
242
- const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1 >(glsl::gl_SubgroupID (), idx);
243
- scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, scan_local[Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
244
- }
245
- }
246
- scratchAccessor.workgroupExecutionAndMemoryBarrier ();
245
+ reduce<Config, BinOp, 2 , device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
247
246
247
+ const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
248
248
// level 1 scan
249
249
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
250
250
subgroup2::reduction<params_lv1_t> reduction1;
@@ -300,24 +300,9 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
300
300
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
301
301
BinOp binop;
302
302
303
- const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
304
- subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
305
- // level 0 scan
306
- [unroll]
307
- for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
308
- {
309
- vector_lv0_t value;
310
- dataAccessor.template get<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
311
- value = inclusiveScan0 (value);
312
- dataAccessor.template set<vector_lv0_t, uint32_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
313
- if (Config::electLast ())
314
- {
315
- const uint32_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1 >(glsl::gl_SubgroupID (), idx);
316
- scratchAccessor.template set<scalar_t, uint32_t>(bankedIndex, value[Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
317
- }
318
- }
319
- scratchAccessor.workgroupExecutionAndMemoryBarrier ();
303
+ scan<Config, BinOp, Exclusive, 2 , device_capabilities>::template __doLevel0<DataAccessor, ScratchAccessor, params_lv0_t, vector_lv0_t>(dataAccessor, scratchAccessor);
320
304
305
+ const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
321
306
// level 1 scan
322
307
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
323
308
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
0 commit comments