Skip to content

Commit 0b16307

Browse files
committed
fix 3-level scan downsweep step
1 parent 203c03a commit 0b16307

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -305,15 +305,15 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
305305
// level 1 scan
306306
const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
307307
const uint32_t lv1_num_invoc = Config::SubgroupSize*Config::ItemsPerInvocation_2;
308-
subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
308+
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
309309
if (glsl::gl_SubgroupID() < lv1_num_invoc)
310310
{
311311
vector_lv1_t lv1_val;
312312
[unroll]
313313
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
314314
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
315315
// lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
316-
lv1_val = exclusiveScan1(lv1_val);
316+
lv1_val = inclusiveScan1(lv1_val);
317317
[unroll]
318318
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
319319
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
@@ -333,7 +333,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
333333
[unroll]
334334
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
335335
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
336-
lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
336+
// lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
337337
lv2_val = exclusiveScan2(lv2_val);
338338
[unroll]
339339
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
@@ -347,16 +347,20 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
347347
vector_lv1_t lv1_val;
348348
[unroll]
349349
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
350-
scratchAccessor.template get<scalar_t, uint32_t>(i*Config::SubgroupSize+invocationIndex,lv1_val[i]);
350+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
351+
352+
const scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(lv1_val[Config::ItemsPerInvocation_1-1],1), bool(glsl::gl_SubgroupInvocationID()));
351353

352354
scalar_t lv2_scan;
353355
const uint32_t bankedIndex = Config::template sharedStoreIndex<2>(glsl::gl_SubgroupID());
354-
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
356+
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+bankedIndex, lv2_scan);
355357

356358
[unroll]
357-
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
358-
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
359+
for (uint32_t i = Config::ItemsPerInvocation_1-1; i > 0; i--)
360+
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i-1],lv2_scan));
361+
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, 0), binop(left_last_elem,lv2_scan));
359362
}
363+
scratchAccessor.workgroupExecutionAndMemoryBarrier();
360364

361365
// combine with level 0
362366
[unroll]

0 commit comments

Comments
 (0)