@@ -305,15 +305,15 @@ struct scan&lt;Config, BinOp, Exclusive, 3, device_capabilities&gt;
         // level 1 scan
         const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
         const uint32_t lv1_num_invoc = Config::SubgroupSize*Config::ItemsPerInvocation_2;
-        subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
+        subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
         if (glsl::gl_SubgroupID() < lv1_num_invoc)
         {
             vector_lv1_t lv1_val;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
                 scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
             // lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
-            lv1_val = exclusiveScan1(lv1_val);
+            lv1_val = inclusiveScan1(lv1_val);
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
                 scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
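The hunk above swaps the level-1 pass from subgroup2::exclusive_scan to subgroup2::inclusive_scan; the exclusive shift is deferred to the combine step in the last hunk, where subgroupShuffleUp pulls in the previous invocation's last element. Doing the scan inclusively and shifting once at the end avoids threading the identity element through the shared-memory layout. Below is a minimal standalone sketch of that inclusive-then-shift idea using stock SM 6.0 wave intrinsics rather than Nabla's subgroup2 wrappers; the buffer name and addition as the binary op are assumptions for illustration.

```hlsl
// Standalone sketch, not Nabla code: run an inclusive wave scan, then shift
// the result down one lane afterwards to recover the exclusive scan.
RWStructuredBuffer<float> data; // hypothetical buffer

[numthreads(64, 1, 1)]
void main(uint3 dtid : SV_DispatchThreadID)
{
    const float v = data[dtid.x];
    // WavePrefixSum is exclusive, so adding the lane's own value makes it inclusive
    const float inclusive = WavePrefixSum(v) + v;

    const uint lane = WaveGetLaneIndex();
    const uint left = lane == 0u ? 0u : lane - 1u; // keep the source lane in range
    const float fromLeft = WaveReadLaneAt(inclusive, left);
    // lane 0 has no left neighbour, so it takes the identity (0 for addition)
    const float exclusive = lane == 0u ? 0.0f : fromLeft;
    data[dtid.x] = exclusive;
}
```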
@@ -333,7 +333,7 @@ struct scan&lt;Config, BinOp, Exclusive, 3, device_capabilities&gt;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
                 scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
-            lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
+            // lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
             lv2_val = exclusiveScan2(lv2_val);
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
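With level 1 now inclusive, the identity seeding of lv2_val[0] is commented out; level 2 still runs subgroup2::exclusive_scan over the per-subgroup reductions, and an exclusive scan already hands the first invocation the identity. A rough sketch of what that coarse pass computes, with WavePrefixSum standing in for the exclusive scan and a groupshared array standing in for the scratchAccessor region (both assumptions):

```hlsl
// Standalone sketch, not Nabla code: one reduction value per subgroup sits in
// shared memory; an exclusive scan over them yields each subgroup's offset.
groupshared float subgroupSums[64]; // placeholder for the scratch region

void level2Scan()
{
    const uint lane = WaveGetLaneIndex();
    const float sum = subgroupSums[lane]; // each lane reads one subgroup's reduction
    // exclusive prefix sum: lane 0 receives 0, the additive identity,
    // so no explicit seeding of the first element is needed
    subgroupSums[lane] = WavePrefixSum(sum);
}
```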
@@ -347,16 +347,20 @@ struct scan&lt;Config, BinOp, Exclusive, 3, device_capabilities&gt;
             vector_lv1_t lv1_val;
             [unroll]
             for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template get<scalar_t, uint32_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
+                scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
+
+            const scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(lv1_val[Config::ItemsPerInvocation_1-1],1), bool(glsl::gl_SubgroupInvocationID()));
 
             scalar_t lv2_scan;
             const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
-            scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
+            scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
 
             [unroll]
-            for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
-                scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
+            for (uint32_t i = Config::ItemsPerInvocation_1-1; i > 0; i--)
+                scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i-1],lv2_scan));
+            scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, 0), binop(left_last_elem,lv2_scan));
         }
+        scratchAccessor.workgroupExecutionAndMemoryBarrier();
 
         // combine with level 0
         [unroll]
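The last hunk reworks the combine step: the load loop now uses the same sharedLoadIndex<1> addressing as the store side, the level-2 result is read rather than overwritten (the old line stored the still-uninitialized lv2_scan), and the store loop writes each level-1 value shifted down by one element, with element 0 seeded from the left neighbour's last inclusive element via subgroupShuffleUp. Together this turns the inclusive level-1 results back into an exclusive scan, and the added workgroupExecutionAndMemoryBarrier makes those stores visible before the level-0 combine reads them. A sketch of the shift-by-one store, with addition standing in for binop and ITEMS_PER_LANE as a hypothetical stand-in for Config::ItemsPerInvocation_1:

```hlsl
// Standalone sketch, not Nabla code: turn per-lane inclusive results into an
// exclusive scan while folding in the coarser level's offset.
#define ITEMS_PER_LANE 4 // stand-in for Config::ItemsPerInvocation_1

void shiftCombine(inout float vals[ITEMS_PER_LANE], const float lv2_scan)
{
    const uint lane = WaveGetLaneIndex();
    const uint left = lane == 0u ? 0u : lane - 1u;
    // previous lane's last inclusive element; the identity for lane 0
    const float prevLast = WaveReadLaneAt(vals[ITEMS_PER_LANE - 1], left);
    const float leftLast = lane == 0u ? 0.0f : prevLast;

    [unroll]
    for (uint i = ITEMS_PER_LANE - 1; i > 0u; i--)
        vals[i] = vals[i - 1] + lv2_scan; // element i takes element i-1's value
    vals[0] = leftLast + lv2_scan;        // element 0 takes the left neighbour's
}
```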