@@ -94,20 +94,20 @@ struct reduce<Config, BinOp, 2, device_capabilities>
94
94
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
95
95
BinOp binop;
96
96
97
- vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
98
97
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
99
98
// level 0 scan
100
99
subgroup2::reduction<params_lv0_t> reduction0;
101
100
[unroll]
102
101
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
103
102
{
104
- dataAccessor.get (idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
105
- scan_local[idx] = reduction0 (scan_local[idx]);
103
+ vector_lv0_t scan_local;
104
+ dataAccessor.get (idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
105
+ scan_local = reduction0 (scan_local);
106
106
if (glsl::gl_SubgroupInvocationID ()==Config::SubgroupSize-1 )
107
107
{
108
108
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
109
109
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1 )) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
110
- scratchAccessor.set (bankedIndex, scan_local[idx][ Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
110
+ scratchAccessor.set (bankedIndex, scan_local[Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
111
111
}
112
112
}
113
113
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
@@ -227,20 +227,20 @@ struct reduce<Config, BinOp, 3, device_capabilities>
227
227
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
228
228
BinOp binop;
229
229
230
- vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
231
230
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
232
231
// level 0 scan
233
232
subgroup2::reduction<params_lv0_t> reduction0;
234
233
[unroll]
235
234
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
236
235
{
237
- dataAccessor.get (idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
238
- scan_local[idx] = reduction0 (scan_local[idx]);
236
+ vector_lv0_t scan_local;
237
+ dataAccessor.get (idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local);
238
+ scan_local = reduction0 (scan_local);
239
239
if (glsl::gl_SubgroupInvocationID ()==Config::SubgroupSize-1 )
240
240
{
241
241
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
242
242
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1 )) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
243
- scratchAccessor.set (bankedIndex, scan_local[idx][ Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
243
+ scratchAccessor.set (bankedIndex, scan_local[Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
244
244
}
245
245
}
246
246
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
0 commit comments