Skip to content

Commit 7b15a54

Browse files
committed
do inclusive scan on upsweep and shift left on downsweep
1 parent 7d77d30 commit 7b15a54

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,14 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
179179

180180
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
181181
// level 1 scan
182-
subgroup2::exclusive_scan<params_lv1_t> exclusiveScan1;
182+
subgroup2::inclusive_scan<params_lv1_t> inclusiveScan1;
183183
if (glsl::gl_SubgroupID() == 0)
184184
{
185185
vector_lv1_t lv1_val;
186186
[unroll]
187187
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
188188
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
189-
// lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
190-
lv1_val = exclusiveScan1(lv1_val);
189+
lv1_val = inclusiveScan1(lv1_val);
191190
[unroll]
192191
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
193192
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
@@ -201,9 +200,12 @@ struct scan<Config, BinOp, Exclusive, 2, device_capabilities>
201200
vector_lv0_t value;
202201
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
203202

204-
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()), idx);
203+
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
205204
scalar_t left;
206-
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
205+
if (idx != 0 || glsl::gl_SubgroupID() != 0)
206+
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
207+
else
208+
left = BinOp::identity;
207209
if (Exclusive)
208210
{
209211
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));
@@ -245,7 +247,7 @@ struct reduce<Config, BinOp, 3, device_capabilities>
245247

246248
const uint16_t invocationIndex = workgroup::SubgroupContiguousIndex();
247249
// level 1 scan
248-
const uint32_t lv1_smem_size = Config::__ItemsPerVirtualWorkgroup;
250+
const uint32_t lv1_smem_size = Config::LevelInputCount_1;
249251
subgroup2::reduction<params_lv1_t> reduction1;
250252
if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
251253
{
@@ -311,7 +313,6 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
311313
[unroll]
312314
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
313315
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
314-
// lv1_val[0] = hlsl::mix(BinOp::identity, lv1_val[0], bool(invocationIndex));
315316
lv1_val = inclusiveScan1(lv1_val);
316317
[unroll]
317318
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
@@ -325,15 +326,14 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
325326
scratchAccessor.workgroupExecutionAndMemoryBarrier();
326327

327328
// level 2 scan
328-
subgroup2::exclusive_scan<params_lv2_t> exclusiveScan2;
329+
subgroup2::inclusive_scan<params_lv2_t> inclusiveScan2;
329330
if (glsl::gl_SubgroupID() == 0)
330331
{
331332
vector_lv2_t lv2_val;
332333
[unroll]
333334
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
334335
scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
335-
// lv2_val[0] = hlsl::mix(BinOp::identity, lv2_val[0], bool(invocationIndex));
336-
lv2_val = exclusiveScan2(lv2_val);
336+
lv2_val = inclusiveScan2(lv2_val);
337337
[unroll]
338338
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
339339
scratchAccessor.template set<scalar_t, uint16_t>(lv1_smem_size+Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
@@ -344,20 +344,18 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
344344
if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
345345
{
346346
vector_lv1_t lv1_val;
347+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex-uint16_t(1u), Config::ItemsPerInvocation_1-uint16_t(1u)), lv1_val[0]);
347348
[unroll]
348-
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
349-
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
350-
351-
const scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(lv1_val[Config::ItemsPerInvocation_1-1],1), bool(glsl::gl_SubgroupInvocationID()));
349+
for (uint16_t i = 1; i < Config::ItemsPerInvocation_1; i++)
350+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i-uint16_t(1u)), lv1_val[i]);
352351

353352
scalar_t lv2_scan;
354-
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()));
353+
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()-1u));
355354
scratchAccessor.template get<scalar_t, uint16_t>(lv1_smem_size+bankedIndex, lv2_scan);
356355

357356
[unroll]
358-
for (uint16_t i = Config::ItemsPerInvocation_1-1; i > 0; i--)
359-
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i-1],lv2_scan));
360-
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, 0), binop(left_last_elem,lv2_scan));
357+
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
358+
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
361359
}
362360
scratchAccessor.workgroupExecutionAndMemoryBarrier();
363361

@@ -368,9 +366,12 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
368366
vector_lv0_t value;
369367
dataAccessor.template get<vector_lv0_t, uint16_t>(idx * Config::WorkgroupSize + virtualInvocationIndex, value);
370368

371-
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(glsl::gl_SubgroupID(), idx);
369+
const uint16_t bankedIndex = Config::template sharedStoreIndexFromVirtualIndex<1>(uint16_t(glsl::gl_SubgroupID()-1u), idx);
372370
scalar_t left;
373-
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
371+
if (idx != 0 || glsl::gl_SubgroupID() != 0)
372+
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex,left);
373+
else
374+
left = BinOp::identity;
374375
if (Exclusive)
375376
{
376377
scalar_t left_last_elem = hlsl::mix(BinOp::identity, glsl::subgroupShuffleUp<scalar_t>(value[Config::ItemsPerInvocation_0-1],1), bool(glsl::gl_SubgroupInvocationID()));

0 commit comments

Comments
 (0)