Skip to content

Commit 1b6fa5f

Browse files
fix final bugs in ballot scans
1 parent b638080 commit 1b6fa5f

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,13 @@ template<uint16_t ItemCount, class BallotAccessor>
8181
uint16_t ballotCountedBitDWORD(NBL_REF_ARG(BallotAccessor) ballotAccessor)
8282
{
8383
const uint32_t index = SubgroupContiguousIndex();
84-
if (index<impl::ballot_dword_count<ItemCount>::value)
84+
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
85+
if (index<DWORDCount)
8586
{
8687
uint32_t bitfield = ballotAccessor.get(index);
87-
// strip unwanted bits from bitfield
88+
// strip unwanted bits from bitfield of the last item
8889
const uint16_t Remainder = ItemCount&31;
89-
if (Remainder!=0)
90+
if (Remainder!=0 && index==DWORDCount-1)
9091
bitfield &= (0x1u<<Remainder)-1;
9192
return uint16_t(countbits(bitfield));
9293
}
@@ -96,15 +97,21 @@ uint16_t ballotCountedBitDWORD(NBL_REF_ARG(BallotAccessor) ballotAccessor)
9697
template<bool Exclusive, uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
9798
uint16_t ballotScanBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
9899
{
99-
const uint32_t localBitfield = ballotAccessor.get(impl::getDWORD(SubgroupContiguousIndex()));
100+
const uint16_t subgroupIndex = SubgroupContiguousIndex();
101+
const uint16_t bitfieldIndex = impl::getDWORD(subgroupIndex);
102+
const uint32_t localBitfield = ballotAccessor.get(bitfieldIndex);
100103

101104
static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
102-
const uint32_t count = exclusive_scan<plus<uint32_t>,DWORDCount>::template __call<ArithmeticAccessor>(
105+
uint32_t count = exclusive_scan<plus<uint32_t>,DWORDCount>::template __call<ArithmeticAccessor>(
103106
ballotCountedBitDWORD<ItemCount,BallotAccessor>(ballotAccessor),
104107
arithmeticAccessor
105108
);
106-
return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0]));
107-
// return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0])+count);
109+
arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
110+
if (subgroupIndex<DWORDCount)
111+
arithmeticAccessor.set(subgroupIndex,count);
112+
arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
113+
count = arithmeticAccessor.get(bitfieldIndex);
114+
return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0])+count);
108115
}
109116
}
110117

0 commit comments

Comments
 (0)