Skip to content

Commit 90d3579

Browse files
committed
fix scans for level 1+
1 parent 127c6d9 commit 90d3579

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,15 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
179179

180180
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
181181
// level 1 scan
182-
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
182+
subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
183183
if (glsl::gl_SubgroupID() == 0)
184184
{
185185
vector_lv1_t lv1_val;
186186
[unroll]
187187
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
188-
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
189-
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
190-
lv1_val = inclusiveScan1(lv1_val);
188+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
189+
// lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
190+
lv1_val = exclusiveScan1(lv1_val);
191191
[unroll]
192192
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
193193
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
@@ -304,15 +304,16 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
304304
const uint32_t invocationIndex = workgroup::SubgroupContiguousIndex();
305305
// level 1 scan
306306
const uint32_t lv1_smem_size = Config::SubgroupsSize*Config::ItemsPerInvocation_1;
307-
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
308-
if (glsl::gl_SubgroupID() < Config::SubgroupsSize*Config::ItemsPerInvocation_2)
307+
const uint32_t lv1_num_invoc = Config::SubgroupsSize*Config::ItemsPerInvocation_2;
308+
subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
309+
if (glsl::gl_SubgroupID() < lv1_num_invoc)
309310
{
310311
vector_lv1_t lv1_val;
311312
[unroll]
312313
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
313-
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i)-1,lv1_val[i]);
314-
lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
315-
lv1_val = inclusiveScan1(lv1_val);
314+
scratchAccessor.template get<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
315+
// lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
316+
lv1_val = exclusiveScan1(lv1_val);
316317
[unroll]
317318
for (uint32_t i = 0; i < Config::ItemsPerInvocation_1; i++)
318319
scratchAccessor.template set<scalar_t, uint32_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
@@ -325,23 +326,23 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
325326
scratchAccessor.workgroupExecutionAndMemoryBarrier();
326327

327328
// level 2 scan
328-
subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2;
329+
subgroup2::exclusive_scan<params_lv2_t> exclusiveScan2;
329330
if (glsl::gl_SubgroupID() == 0)
330331
{
331332
vector_lv2_t lv2_val;
332333
[unroll]
333334
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
334-
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i)-1,lv2_val[i]);
335+
scratchAccessor.template get<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
335336
lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
336-
lv2_val = inclusiveScan2(lv2_val);
337+
lv2_val = exclusiveScan2(lv2_val);
337338
[unroll]
338339
for (uint32_t i = 0; i < Config::ItemsPerInvocation_2; i++)
339340
scratchAccessor.template set<scalar_t, uint32_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
340341
}
341342
scratchAccessor.workgroupExecutionAndMemoryBarrier();
342343

343344
// combine with level 1
344-
if (glsl::gl_SubgroupID() < lv1_smem_size)
345+
if (glsl::gl_SubgroupID() < lv1_num_invoc)
345346
{
346347
vector_lv1_t lv1_val;
347348
[unroll]

0 commit comments

Comments
 (0)