Skip to content

Commit e230d06

Browse files
committed
fixes to 3 level scan
1 parent da6c313 commit e230d06

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,12 @@ struct ArithmeticConfiguration
116116
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
117117
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
118118
{
119+
uint16_t smem_offset = 0u;
120+
if (level == 2)
121+
smem_offset += LevelInputCount_1;
122+
119123
if (level == LevelCount-1)
120-
return component * SubgroupSize + invocationIndex;
124+
return component * SubgroupSize + invocationIndex + smem_offset;
121125
else
122126
return component * __SubgroupsPerVirtualWorkgroup + invocationIndex;
123127
}

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,17 +357,19 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
357357
if (glsl::gl_SubgroupID() < Config::LevelInputCount_2)
358358
{
359359
vector_lv1_t lv1_val;
360-
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex-uint16_t(1u), Config::ItemsPerInvocation_1-uint16_t(1u)), lv1_val[0]);
361360
[unroll]
362-
for (uint16_t i = 1; i < Config::ItemsPerInvocation_1; i++)
363-
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i-uint16_t(1u)), lv1_val[i]);
361+
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
362+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), lv1_val[i]);
364363

365364
scalar_t lv2_scan;
366365
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()-1u));
367-
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, lv2_scan);
366+
if (glsl::gl_SubgroupID() != 0)
367+
scratchAccessor.template get<scalar_t, uint16_t>(bankedIndex, lv2_scan);
368+
else
369+
lv2_scan = BinOp::identity;
368370

369371
[unroll]
370-
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i--)
372+
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
371373
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i), binop(lv1_val[i],lv2_scan));
372374
}
373375
scratchAccessor.workgroupExecutionAndMemoryBarrier();

0 commit comments

Comments
 (0)