@@ -81,12 +81,13 @@ template<uint16_t ItemCount, class BallotAccessor>
81
81
uint16_t ballotCountedBitDWORD (NBL_REF_ARG (BallotAccessor) ballotAccessor)
82
82
{
83
83
const uint32_t index = SubgroupContiguousIndex ();
84
- if (index<impl::ballot_dword_count<ItemCount>::value)
84
+ static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
85
+ if (index<DWORDCount)
85
86
{
86
87
uint32_t bitfield = ballotAccessor.get (index);
87
- // strip unwanted bits from bitfield
88
+ // strip unwanted bits from bitfield of the last item
88
89
const uint16_t Remainder = ItemCount&31 ;
89
- if (Remainder!=0 )
90
+ if (Remainder!=0 && index==DWORDCount- 1 )
90
91
bitfield &= (0x1u<<Remainder)-1 ;
91
92
return uint16_t (countbits (bitfield));
92
93
}
@@ -96,15 +97,21 @@ uint16_t ballotCountedBitDWORD(NBL_REF_ARG(BallotAccessor) ballotAccessor)
96
97
// Scan (prefix count) of the set bits of a ballot bitfield read through `ballotAccessor`.
//
// Template parameters:
//   Exclusive          - selects `gl_SubgroupLtMask` (true) or `gl_SubgroupLeMask` (false) for the
//                        contribution of the invocation's own DWORD.
//   ItemCount          - number of ballot items; determines `DWORDCount` through
//                        `impl::ballot_dword_count<ItemCount>`.
//   BallotAccessor     - must provide `get(index)` returning a 32-bit ballot DWORD.
//   ArithmeticAccessor - must provide `get(index)`, `set(index,value)` and
//                        `workgroupExecutionAndMemoryBarrier()`; it is used both by the scan and as
//                        the exchange area for the per-DWORD prefixes.
//
// Returns the number of set bits selected by the mask inside this invocation's DWORD, plus the
// scanned bit count read back for that DWORD.
//
// NOTE: both barrier calls are issued outside of any branch, so every invocation that calls this
// function reaches them.
template<bool Exclusive, uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
uint16_t ballotScanBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
{
    const uint16_t subgroupIndex = SubgroupContiguousIndex();
    // index of the ballot DWORD this invocation reads (mapping is defined by `impl::getDWORD`)
    const uint16_t bitfieldIndex = impl::getDWORD(subgroupIndex);
    const uint32_t localBitfield = ballotAccessor.get(bitfieldIndex);

    static const uint16_t DWORDCount = impl::ballot_dword_count<ItemCount>::value;
    // exclusive `plus` scan over the values every invocation gets from `ballotCountedBitDWORD`
    uint32_t count = exclusive_scan<plus<uint32_t>,DWORDCount>::template __call<ArithmeticAccessor>(
        ballotCountedBitDWORD<ItemCount,BallotAccessor>(ballotAccessor),
        arithmeticAccessor
    );

    // The scan result lives in the invocation that owns the DWORD, but each invocation needs the
    // result belonging to `bitfieldIndex`, so the results are exchanged through the accessor.
    // First barrier: issued before the accessor, which the scan above also received, is overwritten.
    arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
    if (subgroupIndex<DWORDCount)
        arithmeticAccessor.set(subgroupIndex,count);
    // Second barrier: the stores above must be visible before any invocation reads them back.
    arithmeticAccessor.workgroupExecutionAndMemoryBarrier();
    count = arithmeticAccessor.get(bitfieldIndex);

    // NOTE(review): only component [0] of the subgroup mask is applied to the DWORD; this presumably
    // relies on the invocation's bit position inside its DWORD matching its subgroup invocation ID
    // - TODO confirm for subgroup sizes other than 32.
    return uint16_t(countbits(localBitfield&(Exclusive ? glsl::gl_SubgroupLtMask():glsl::gl_SubgroupLeMask())[0])+count);
}
109
116
}
110
117
0 commit comments