@@ -151,20 +151,21 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
151
151
using params_lv1_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_1, device_capabilities>;
152
152
BinOp binop;
153
153
154
- vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
155
154
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
156
155
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
157
156
// level 0 scan
158
157
[unroll]
159
158
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
160
159
{
161
- dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
162
- scan_local[idx] = inclusiveScan0 (scan_local[idx]);
160
+ vector_lv0_t value;
161
+ dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
162
+ value = inclusiveScan0 (value);
163
+ dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
163
164
if (glsl::gl_SubgroupInvocationID ()==Config::SubgroupSize-1 )
164
165
{
165
166
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
166
167
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1 )) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
167
- scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[idx] [Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
168
+ scratchAccessor.template set<scalar_t>(bankedIndex, value [Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
168
169
}
169
170
}
170
171
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
@@ -188,23 +189,26 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
188
189
[unroll]
189
190
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
190
191
{
192
+ vector_lv0_t value;
193
+ dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
194
+
191
195
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
192
196
scalar_t left;
193
197
scratchAccessor.template get<scalar_t>(virtualSubgroupID,left);
194
198
if (Exclusive)
195
199
{
196
- scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx] [Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
200
+ scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value [Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
197
201
[unroll]
198
202
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_0; i++)
199
- scan_local[idx][ Config::ItemsPerInvocation_0-i-1 ] = binop (left, hlsl::mix (scan_local[idx] [Config::ItemsPerInvocation_0-i-2 ], left_last_elem, (Config::ItemsPerInvocation_0-i-1 ==0 )));
203
+ value[ Config::ItemsPerInvocation_0-i-1 ] = binop (left, hlsl::mix (value [Config::ItemsPerInvocation_0-i-2 ], left_last_elem, (Config::ItemsPerInvocation_0-i-1 ==0 )));
200
204
}
201
205
else
202
206
{
203
207
[unroll]
204
208
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_0; i++)
205
- scan_local[idx][ i] = binop (left, scan_local[idx] [i]);
209
+ value[ i] = binop (left, value [i]);
206
210
}
207
- dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx] );
211
+ dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value );
208
212
}
209
213
}
210
214
};
@@ -303,20 +307,21 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
303
307
using params_lv2_t = subgroup2::ArithmeticParams<config_t, BinOp, Config::ItemsPerInvocation_2, device_capabilities>;
304
308
BinOp binop;
305
309
306
- vector_lv0_t scan_local[Config::VirtualWorkgroupSize / Config::WorkgroupSize];
307
310
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex ();
308
311
subgroup2::inclusive_scan<params_lv0_t> inclusiveScan0;
309
312
// level 0 scan
310
313
[unroll]
311
314
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
312
315
{
313
- dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx]);
314
- scan_local[idx] = inclusiveScan0 (scan_local[idx]);
316
+ vector_lv0_t value;
317
+ dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
318
+ value = inclusiveScan0 (value);
319
+ dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
315
320
if (glsl::gl_SubgroupInvocationID ()==Config::SubgroupSize-1 )
316
321
{
317
322
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
318
323
const uint32_t bankedIndex = (virtualSubgroupID & (Config::ItemsPerInvocation_1-1 )) * Config::SubgroupsPerVirtualWorkgroup + (virtualSubgroupID/Config::ItemsPerInvocation_1);
319
- scratchAccessor.template set<scalar_t>(bankedIndex, scan_local[idx] [Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
324
+ scratchAccessor.template set<scalar_t>(bankedIndex, value [Config::ItemsPerInvocation_0-1 ]); // set last element of subgroup scan (reduction) to level 1 scan
320
325
}
321
326
}
322
327
scratchAccessor.workgroupExecutionAndMemoryBarrier ();
@@ -368,23 +373,26 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
368
373
[unroll]
369
374
for (uint32_t idx = 0 , virtualInvocationIndex = invocationIndex; idx < Config::VirtualWorkgroupSize / Config::WorkgroupSize; idx++)
370
375
{
376
+ vector_lv0_t value;
377
+ dataAccessor.template get<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
378
+
371
379
const uint32_t virtualSubgroupID = idx * (Config::WorkgroupSize >> Config::SubgroupSizeLog2) + glsl::gl_SubgroupID ();
372
380
const scalar_t left;
373
381
scratchAccessor.template get<scalar_t>(virtualSubgroupID, left);
374
382
if (Exclusive)
375
383
{
376
- scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(scan_local[idx] [Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
384
+ scalar_t left_last_elem = hlsl::mix (BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value [Config::ItemsPerInvocation_0-1 ],1 ), bool (glsl::gl_SubgroupInvocationID ()));
377
385
[unroll]
378
386
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_0; i++)
379
- scan_local[idx][ Config::ItemsPerInvocation_0-i-1 ] = binop (left, hlsl::mix (scan_local[idx] [Config::ItemsPerInvocation_0-i-2 ], left_last_elem, (Config::ItemsPerInvocation_0-i-1 ==0 )));
387
+ value[ Config::ItemsPerInvocation_0-i-1 ] = binop (left, hlsl::mix (value [Config::ItemsPerInvocation_0-i-2 ], left_last_elem, (Config::ItemsPerInvocation_0-i-1 ==0 )));
380
388
}
381
389
else
382
390
{
383
391
[unroll]
384
392
for (uint32_t i = 0 ; i < Config::ItemsPerInvocation_0; i++)
385
- scan_local[idx][ i] = binop (left, scan_local[idx] [i]);
393
+ value[ i] = binop (left, value [i]);
386
394
}
387
- dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, scan_local[idx] );
395
+ dataAccessor.template set<vector_lv0_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value );
388
396
}
389
397
}
390
398
};
0 commit comments