Skip to content

Commit 6b33b97

Browse files
ok at least ItemCount==WorkgroupSize work as reductions
1 parent 8469d0f commit 6b33b97

File tree

3 files changed

+139
-158
lines changed

3 files changed

+139
-158
lines changed

include/nbl/builtin/hlsl/workgroup/arithmetic.hlsl

Lines changed: 67 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -4,123 +4,114 @@
44
#ifndef _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
55
#define _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
66

7-
#include "nbl/builtin/hlsl/cpp_compat.hlsl"
7+
88
#include "nbl/builtin/hlsl/functional.hlsl"
99
#include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
1010
#include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
1111
#include "nbl/builtin/hlsl/workgroup/shared_scan.hlsl"
1212

13+
1314
namespace nbl
1415
{
1516
namespace hlsl
1617
{
1718
namespace workgroup
1819
{
1920

20-
#define REDUCE Reduce<T, subgroup::inclusive_scan<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_>
21-
#define SCAN(isExclusive) Scan<T, Binop, subgroup::inclusive_scan<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_, isExclusive>
22-
template<typename T, class Binop, class SharedAccessor>
23-
T reduction(T value, NBL_REF_ARG(SharedAccessor) accessor)
21+
// TODO: with Boost PP at some point
22+
//#define NBL_ALIAS_CALL_OPERATOR_TO_STATIC_IMPL(OPTIONAL_TEMPLATE,RETURN_TYPE,/*tuples of argument types and names*/...)
23+
//#define NBL_ALIAS_TEMPLATED_CALL_OPERATOR_TO_IMPL(TEMPLATE,RETURN_TYPE,/*tuples of argument types and names*/...)
24+
25+
template<class BinOp, uint16_t ItemCount>
26+
struct reduction
2427
{
25-
REDUCE reduce = REDUCE::create();
26-
reduce(value, accessor);
27-
accessor.main.workgroupExecutionAndMemoryBarrier();
28-
T retVal = Broadcast<uint, SharedAccessor>(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel);
29-
return retVal;
30-
}
28+
using type_t = typename BinOp::type_t;
3129

32-
template<typename T, class Binop, class SharedAccessor>
33-
T inclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor)
30+
template<class Accessor>
31+
static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor)
32+
{
33+
impl::reduce<BinOp,ItemCount> fn;
34+
fn.template __call<Accessor>(value,accessor);
35+
accessor.workgroupExecutionAndMemoryBarrier();
36+
return Broadcast<type_t,Accessor>(fn.lastLevelScan,accessor,fn.lastInvocationInLevel);
37+
}
38+
};
39+
40+
template<class BinOp, uint16_t ItemCount>
41+
struct inclusive_scan
3442
{
35-
SCAN(false) incl_scan = SCAN(false)::create();
36-
T retVal = incl_scan(value, accessor);
37-
return retVal;
38-
}
43+
using type_t = typename BinOp::type_t;
44+
45+
template<class Accessor>
46+
static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor)
47+
{
48+
impl::scan<BinOp,false,ItemCount> fn;
49+
return fn.template __call<Accessor>(value,accessor);
50+
}
51+
};
3952

40-
template<typename T, class Binop, class SharedAccessor>
41-
T exclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor)
53+
template<class BinOp, uint16_t ItemCount>
54+
struct exclusive_scan
4255
{
43-
SCAN(true) excl_scan = SCAN(true)::create();
44-
T retVal = excl_scan(value, accessor);
45-
return retVal;
46-
}
56+
using type_t = typename BinOp::type_t;
4757

48-
#undef REDUCE
49-
#undef SCAN
58+
template<class Accessor>
59+
static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor)
60+
{
61+
impl::scan<BinOp,true,ItemCount> fn;
62+
return fn.template __call<Accessor>(value,accessor);
63+
}
64+
};
5065

51-
#define REDUCE Reduce<uint, subgroup::inclusive_scan<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount>
52-
#define SCAN Scan<uint, plus<uint>, subgroup::inclusive_scan<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount, true>
5366
/**
54-
* Gives us the sum (reduction) of all ballots for the workgroup.
67+
* Gives us the sum (reduction) of all ballots for the ItemCount bits of a workgroup.
5568
*
5669
* Only the first few invocations are used for performing the sum
57-
* since we only have `uballotBitfieldCount` amount of uints that we need
70+
* since there is only one uint per 32 ballot bits (`ItemCount/32` uints) that we need
5871
* to add together.
5972
*
6073
* We add them all in the shared array index after the last DWORD
6174
* that is used for the ballots. For example, if the workgroup size is 128,
6275
* then the array index in which we accumulate the sum is `4` since
6376
* indexes 0..3 are used for ballots.
6477
*/
65-
template<class SharedAccessor>
66-
uint ballotBitCount(NBL_REF_ARG(SharedAccessor) accessor)
78+
namespace impl
6779
{
68-
uint participatingBitfield = 0;
69-
if(SubgroupContiguousIndex() < impl::uballotBitfieldCount)
70-
{
71-
participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex());
72-
}
73-
accessor.ballot.workgroupExecutionAndMemoryBarrier();
74-
REDUCE reduce = REDUCE::create();
75-
reduce(countbits(participatingBitfield), accessor);
76-
accessor.main.workgroupExecutionAndMemoryBarrier();
77-
return Broadcast<uint, SharedAccessor>(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel);
80+
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor, template<class,uint16_t> class op_t>
81+
uint32_t ballotPolyCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor, NBL_REF_ARG(uint32_t) localBitfield)
82+
{
83+
localBitfield = 0u;
84+
if (SubgroupContiguousIndex()<impl::BallotDWORDCount(ItemCount))
85+
localBitfield = ballotAccessor.get(SubgroupContiguousIndex());
86+
return op_t<plus<uint32_t>,impl::ballot_dword_count<ItemCount>::value>::template __call<ArithmeticAccessor>(countbits(localBitfield),arithmeticAccessor);
87+
}
7888
}
7989

80-
template<class SharedAccessor>
81-
uint ballotScanBitCount(const bool exclusive, NBL_REF_ARG(SharedAccessor) accessor)
90+
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
91+
uint16_t ballotBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
8292
{
83-
const uint _dword = impl::getDWORD(SubgroupContiguousIndex());
84-
const uint localBitfield = accessor.ballot.get(_dword);
85-
uint globalCount;
86-
{
87-
uint participatingBitfield;
88-
if(SubgroupContiguousIndex() < impl::uballotBitfieldCount)
89-
{
90-
participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex());
91-
}
92-
// scan hierarchically, invocations with `SubgroupContiguousIndex() >= uballotBitfieldCount` will have garbage here
93-
accessor.ballot.workgroupExecutionAndMemoryBarrier();
94-
95-
SCAN scan = SCAN::create();
96-
uint bitscan = scan(countbits(participatingBitfield), accessor);
97-
98-
accessor.main.set(SubgroupContiguousIndex(), bitscan);
99-
accessor.main.workgroupExecutionAndMemoryBarrier();
100-
101-
// fix it (abuse the fact memory is left over)
102-
globalCount = _dword != 0u ? accessor.main.get(_dword) : 0u;
103-
accessor.main.workgroupExecutionAndMemoryBarrier();
104-
}
105-
const uint mask = (exclusive ? 0x7fFFffFFu:0xFFffFFffu)>>(31u-(SubgroupContiguousIndex()&31u));
106-
return globalCount + countbits(localBitfield & mask);
93+
uint32_t dummy;
94+
return uint16_t(impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,reduction>(ballotAccessor,arithmeticAccessor,dummy));
10795
}
10896

109-
template<class SharedAccessor>
110-
uint ballotInclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor)
97+
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
98+
uint16_t ballotInclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
11199
{
112-
return ballotScanBitCount<SharedAccessor>(false, accessor);
100+
uint32_t localBitfield;
101+
uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield);
102+
// using only component [0] of the mask is on purpose, only its least significant 32 bits are of interest
103+
return uint16_t(countbits(glsl::gl_SubgroupLeMask()[0]&localBitfield)+count);
113104
}
114105

115-
template<class SharedAccessor>
116-
uint ballotExclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor)
106+
template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor>
107+
uint16_t ballotExclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor)
117108
{
118-
return ballotScanBitCount<SharedAccessor>(true, accessor);
109+
uint32_t localBitfield;
110+
uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield);
111+
// using only component [0] of the mask is on purpose, only its least significant 32 bits are of interest
112+
return uint16_t(countbits(glsl::gl_SubgroupLtMask()[0]&localBitfield)+count);
119113
}
120114

121-
#undef REDUCE
122-
#undef SCAN
123-
124115
}
125116
}
126117
}

0 commit comments

Comments
 (0)