Skip to content

Commit 867103c

Browse files
rewrite workgroup ballot
1 parent 7dc35b1 commit 867103c

File tree

2 files changed

+47
-24
lines changed

2 files changed

+47
-24
lines changed

include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,39 +77,54 @@ struct exclusive_scan
7777
*/
7878
namespace impl
7979
{
80-
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor, template<class,uint16_t> class op_t>
81-
uint32_t ballotPolyCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor, NBL_REF_ARG(uint32_t) localBitfield)
80+
template<uint16_t DWORDCount, class BallotAccessor>
81+
uint16_t ballotCountedBitDWORD(NBL_REF_ARG(BallotAccessor) ballotAccessor)
8282
{
83-
localBitfield = 0u;
84-
if (SubgroupContiguousIndex()<impl::BallotDWORDCount(ItemCount))
85-
localBitfield = ballotAccessor.get(SubgroupContiguousIndex());
86-
return op_t<plus<uint32_t>,impl::ballot_dword_count<ItemCount>::value>::template __call<ArithmeticAccessor>(countbits(localBitfield),arithmeticAccessor);
83+
const uint32_t index = SubgroupContiguousIndex();
84+
if (index<DWORDCount)
85+
{
86+
const uint32_t bitfield = ballotAccessor.get(index);
87+
// FIXME: stip unused bits from bitfield
88+
return uint16_t(countbits(bitfield));
89+
}
90+
return 0;
91+
}
92+
93+
template<bool Exclusive, uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
94+
uint16_t ballotScanBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
95+
{
96+
const uint32_t localBitfield = ballotAccessor.get(impl::getDWORD(SubgroupContiguousIndex()));
97+
98+
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
99+
const uint32_t count = exclusive_scan<plus<uint32_t>,DWORDCount>::template __call<ArithmeticAccessor>(
100+
ballotCountedBitDWORD<DWORDCount,BallotAccessor>(ballotAccessor),
101+
arithmeticAccessor
102+
);
103+
return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0]));
104+
// return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0])+count);
87105
}
88106
}
89107

90108
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
91109
uint16_t ballotBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
92110
{
93-
uint32_t dummy;
94-
return uint16_t(impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,reduction>(ballotAccessor,arithmeticAccessor,dummy));
111+
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
112+
return uint16_t(reduction<plus<uint32_t>,DWORDCount>::template __call<ArithmeticAccessor>(
113+
impl::ballotCountedBitDWORD<DWORDCount,BallotAccessor>(ballotAccessor),
114+
arithmeticAccessor
115+
));
95116
}
96117

97118
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
98119
uint16_t ballotInclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
99120
{
100-
uint32_t localBitfield;
101-
uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield);
102-
// only using part of the mask is on purpose, I'm only interested in LSB
103-
return uint16_t(countbits(glsl::gl_SubgroupLeMask()[0]&localBitfield)+count);
121+
return impl::ballotScanBitCount<false,ItemCount,BallotAccessor,ArithmeticAccessor>(ballotAccessor,arithmeticAccessor);
104122
}
105123

106124
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
107125
uint16_t ballotExclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
108126
{
109-
uint32_t localBitfield;
110-
uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield);
111-
// only using part of the mask is on purpose, I'm only interested in LSB
112-
return uint16_t(countbits(glsl::gl_SubgroupLtMask()[0]&localBitfield)+count);
127+
return impl::ballotScanBitCount<true,ItemCount,BallotAccessor,ArithmeticAccessor>(ballotAccessor,arithmeticAccessor);
113128
}
114129

115130
}

include/nbl/builtin/hlsl/workgroup/ballot.hlsl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,26 @@ struct ballot_dword_count : integral_constant<uint16_t,((ItemCount+31)>>5)> {};
4949
* For example, for a workgroup size 128, 4 DWORDs are needed.
5050
* For each invocation index, we can find its respective DWORD index in the accessor array
5151
* by calling the getDWORD function.
52+
*
53+
* TODO: try do it with 64bit ints instead? (requires modified/adapted accessor)
5254
*/
5355
template<class Accessor>
5456
void ballot(const bool value, NBL_REF_ARG(Accessor) accessor)
5557
{
56-
const uint16_t index = SubgroupContiguousIndex();
57-
const bool initialize = index<impl::BallotDWORDCount(Volume());
58-
if (initialize)
59-
accessor.set(index,0u);
60-
61-
accessor.workgroupExecutionAndMemoryBarrier();
62-
if(value)
63-
accessor.atomicOr(impl::getDWORD(index),1u<<(index&31u));
58+
const uint32_t4 bitfield = glsl::subgroupBallot(value);
59+
60+
const uint16_t subgroupInvocation = uint16_t(glsl::gl_SubgroupInvocationID());
61+
uint16_t destIx = subgroupInvocation;
62+
63+
const uint16_t SubgroupSizeLog2 = uint16_t(glsl::gl_SubgroupSizeLog2());
64+
if (SubgroupSizeLog2>=5)
65+
destIx += uint16_t(glsl::gl_SubgroupID())<<(SubgroupSizeLog2-5);
66+
else
67+
destIx += uint16_t(glsl::gl_SubgroupID())>>(5-SubgroupSizeLog2);
68+
69+
const uint16_t UsefulComponents = impl::getDWORD(uint16_t(glsl::gl_SubgroupSize()));
70+
if (subgroupInvocation<UsefulComponents)
71+
accessor.set(destIx,bitfield[subgroupInvocation]);
6472
}
6573

6674
template<class Accessor>

0 commit comments

Comments
 (0)