@@ -179,15 +179,14 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
-        subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
+        subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
         if (glsl::gl_SubgroupID() == 0)
         {
             vector_lv1_t lv1_val;
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
                 scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
-            // lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
-            lv1_val = exclusiveScan1(lv1_val);
+            lv1_val = inclusiveScan1(lv1_val);
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
                 scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
@@ -201,9 +200,12 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
             vector_lv0_t value;
             dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
 
-            const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
+            const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
             scalar_t left;
-            scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, left);
+            if (idx != 0 || glsl::gl_SubgroupID() != 0)
+                scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, left);
+            else
+                left = BinOp::identity;
             if (Exclusive)
             {
                 scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1], 1), bool(glsl::gl_SubgroupInvocationID()));
@@ -245,7 +247,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
 
         const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
         // level 1 scan
-        const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
+        const uint32_t lv1_smem_size = Config::LevelInputCount_1;
         subgroup2::reduction<params_lv1_t> reduction1;
         if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
         {
@@ -311,7 +313,6 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
                 scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
-            // lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
             lv1_val = inclusiveScan1(lv1_val);
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
@@ -325,15 +326,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
         // level 2 scan
-        subgroup2::exclusive_scan<params_lv2_t> exclusiveScan2;
+        subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2;
         if (glsl::gl_SubgroupID() == 0)
         {
             vector_lv2_t lv2_val;
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
                 scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i), lv2_val[i]);
-            // lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
-            lv2_val = exclusiveScan2(lv2_val);
+            lv2_val = inclusiveScan2(lv2_val);
             [unroll]
             for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
                 scratchAccessor.template set<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i), lv2_val[i]);
@@ -344,20 +344,18 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
         if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
         {
             vector_lv1_t lv1_val;
+            scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex-uint16_t(1u), Config::ItemsPerInvocation_1-uint16_t(1u)), lv1_val[0]);
             [unroll]
-            for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
-
-            const scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(lv1_val[Config::ItemsPerInvocation_1-1], 1), bool(glsl::gl_SubgroupInvocationID()));
+            for (uint16_t i = 1; i < Config::ItemsPerInvocation_1; i++)
+                scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i-uint16_t(1u)), lv1_val[i]);
 
             scalar_t lv2_scan;
-            const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()));
+            const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()-1u));
             scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+bankedIndex, lv2_scan);
 
             [unroll]
-            for (uint16_t i = Config::ItemsPerInvocation_1-1; i > 0; i--)
-                scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i-1], lv2_scan));
-            scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, 0), binop(left_last_elem, lv2_scan));
+            for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
+                scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i], lv2_scan));
         }
         scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
@@ -368,9 +366,12 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
             vector_lv0_t value;
             dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
 
-            const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
+            const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
             scalar_t left;
-            scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, left);
+            if (idx != 0 || glsl::gl_SubgroupID() != 0)
+                scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, left);
+            else
+                left = BinOp::identity;
             if (Exclusive)
             {
                 scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1], 1), bool(glsl::gl_SubgroupInvocationID()));
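
Net effect of the hunks above: the level-1 and level-2 subgroup scans now run as inclusive scans, each reader fetches its "left" partial sum from the previous subgroup's banked slot (hence the gl_SubgroupID()-1u indices), substituting BinOp::identity when no left neighbour exists, and an exclusive result is reconstructed only at the very end, inside the retained if (Exclusive) branch. Below is a minimal sketch of that final inclusive-to-exclusive step, reusing only intrinsics the diff already calls; exclusive_from_inclusive is a hypothetical helper name, not part of the patch.

// Hedged sketch, not the committed code: derive an exclusive scan result
// from an inclusive one within a subgroup. Each lane takes the inclusive
// value of the lane one below it; lane 0, which has no left neighbour,
// takes the identity element of the binary operation instead.
template<typename scalar_t, typename BinOp>
scalar_t exclusive_from_inclusive(scalar_t inclusiveResult)
{
    // pull the inclusive result from the lane one step down the subgroup
    const scalar_t shifted = glsl::subgroupShuffleUp<scalar_t>(inclusiveResult, 1);
    // select the identity on lane 0 (where the shuffle is undefined),
    // and the shuffled neighbour's value everywhere else
    return hlsl::mix(BinOp::identity, shifted, bool(glsl::gl_SubgroupInvocationID()));
}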