Skip to content

Commit 3da175d

Browse files
committed
Add padding to shared memory indexing to avoid bank conflicts
1 parent e230d06 commit 3da175d

File tree

2 files changed

+19
-21
lines changed

2 files changed

+19
-21
lines changed

include/nbl/builtin/hlsl/workgroup2/arithmetic_config.hlsl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ struct ArithmeticConfiguration
7070
NBL_CONSTEXPR_STATIC_INLINE uint32_t SharedScratchElementCount = conditional_value<LevelCount==1,uint16_t,
7171
0,
7272
conditional_value<LevelCount==3,uint16_t,
73-
LevelInputCount_2,
73+
LevelInputCount_2+(SubgroupSize*ItemsPerInvocation_1)-1,
7474
0
7575
>::value + LevelInputCount_1
7676
>::value;
77+
NBL_CONSTEXPR_STATIC_INLINE uint16_t __padding = conditional_value<LevelCount==3,uint16_t,SubgroupSize-1,0>::value;
7778

7879
static bool electLast()
7980
{
@@ -90,40 +91,42 @@ struct ArithmeticConfiguration
9091
// get a coalesced index to store for the next level in shared mem, e.g. level 0 -> level 1
9192
// specify the next level to store values for in template param
9293
// at level==LevelCount-1, it is guaranteed to have SubgroupSize elements
93-
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
94+
template<uint16_t level>// NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
9495
static uint16_t sharedStoreIndex(const uint16_t virtualSubgroupID)
9596
{
96-
uint16_t offsetBySubgroup;
97+
uint16_t nextLevelInvocationCount;
9798
if (level == LevelCount-1)
98-
offsetBySubgroup = SubgroupSize;
99+
nextLevelInvocationCount = SubgroupSize;
99100
else
100-
offsetBySubgroup = __SubgroupsPerVirtualWorkgroup;
101+
nextLevelInvocationCount = __SubgroupsPerVirtualWorkgroup;
101102

102103
if (level==2)
103-
return LevelInputCount_1 + (virtualSubgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_2);
104+
return LevelInputCount_1 + ((SubgroupSize-uint16_t(1u))*ItemsPerInvocation_1) + (virtualSubgroupID & (ItemsPerInvocation_2-uint16_t(1u))) * nextLevelInvocationCount + (virtualSubgroupID/ItemsPerInvocation_2);
104105
else
105-
return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * offsetBySubgroup + (virtualSubgroupID/ItemsPerInvocation_1);
106+
return (virtualSubgroupID & (ItemsPerInvocation_1-uint16_t(1u))) * (nextLevelInvocationCount+__padding) + (virtualSubgroupID/ItemsPerInvocation_1) + virtualSubgroupID/(SubgroupSize*ItemsPerInvocation_1);
106107
}
107108

108-
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
109+
template<uint16_t level>// NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
109110
static uint16_t sharedStoreIndexFromVirtualIndex(const uint16_t subgroupID, const uint16_t workgroupInVirtualIndex)
110111
{
111112
const uint16_t virtualID = virtualSubgroupID(subgroupID, workgroupInVirtualIndex);
112113
return sharedStoreIndex<level>(virtualID);
113114
}
114115

115116
// get the coalesced index in shared mem at the current level
116-
template<uint16_t level NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
117+
template<uint16_t level>// NBL_FUNC_REQUIRES(level>0 && level<LevelCount)
117118
static uint16_t sharedLoadIndex(const uint16_t invocationIndex, const uint16_t component)
118119
{
119-
uint16_t smem_offset = 0u;
120-
if (level == 2)
121-
smem_offset += LevelInputCount_1;
122-
120+
uint16_t levelInvocationCount;
123121
if (level == LevelCount-1)
124-
return component * SubgroupSize + invocationIndex + smem_offset;
122+
levelInvocationCount = SubgroupSize;
123+
else
124+
levelInvocationCount = __SubgroupsPerVirtualWorkgroup;
125+
126+
if (level==2)
127+
return LevelInputCount_1 + ((SubgroupSize-uint16_t(1u))*ItemsPerInvocation_1) + component * levelInvocationCount + invocationIndex + invocationIndex/SubgroupSize;
125128
else
126-
return component * __SubgroupsPerVirtualWorkgroup + invocationIndex;
129+
return component * (levelInvocationCount+__padding) + invocationIndex + invocationIndex/SubgroupSize;
127130
}
128131
};
129132

include/nbl/builtin/hlsl/workgroup2/shared_scan.hlsl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,6 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
330330
[unroll]
331331
for (uint16_t i = 0; i < Config::ItemsPerInvocation_1; i++)
332332
scratchAccessor.template set<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(invocationIndex, i),lv1_val[i]);
333-
if (Config::electLast())
334-
{
335-
const uint16_t bankedIndex = Config::template sharedStoreIndex<2>(uint16_t(glsl::gl_SubgroupID()));
336-
scratchAccessor.template set<scalar_t, uint16_t>(bankedIndex, lv1_val[Config::ItemsPerInvocation_1-1]);
337-
}
338333
}
339334
scratchAccessor.workgroupExecutionAndMemoryBarrier();
340335

@@ -345,7 +340,7 @@ struct scan<Config, BinOp, Exclusive, 3, device_capabilities>
345340
vector_lv2_t lv2_val;
346341
[unroll]
347342
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)
348-
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<2>(invocationIndex, i),lv2_val[i]);
343+
scratchAccessor.template get<scalar_t, uint16_t>(Config::template sharedLoadIndex<1>(((invocationIndex*Config::ItemsPerInvocation_1)+i+1)*Config::SubgroupSize-1, Config::ItemsPerInvocation_1-1),lv2_val[i]);
349344
lv2_val = inclusiveScan2(lv2_val);
350345
[unroll]
351346
for (uint16_t i = 0; i < Config::ItemsPerInvocation_2; i++)

0 commit comments

Comments
 (0)