|
4 | 4 | #ifndef _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
|
5 | 5 | #define _NBL_BUILTIN_HLSL_WORKGROUP_ARITHMETIC_INCLUDED_
|
6 | 6 |
|
7 |
| -#include "nbl/builtin/hlsl/cpp_compat.hlsl" |
| 7 | + |
8 | 8 | #include "nbl/builtin/hlsl/functional.hlsl"
|
9 | 9 | #include "nbl/builtin/hlsl/workgroup/ballot.hlsl"
|
10 | 10 | #include "nbl/builtin/hlsl/workgroup/broadcast.hlsl"
|
11 | 11 | #include "nbl/builtin/hlsl/workgroup/shared_scan.hlsl"
|
12 | 12 |
|
| 13 | + |
13 | 14 | namespace nbl
|
14 | 15 | {
|
15 | 16 | namespace hlsl
|
16 | 17 | {
|
17 | 18 | namespace workgroup
|
18 | 19 | {
|
19 | 20 |
|
20 |
| -#define REDUCE Reduce<T, subgroup::inclusive_scan<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_> |
21 |
| -#define SCAN(isExclusive) Scan<T, Binop, subgroup::inclusive_scan<T, Binop>, SharedAccessor, _NBL_HLSL_WORKGROUP_SIZE_, isExclusive> |
22 |
| -template<typename T, class Binop, class SharedAccessor> |
23 |
| -T reduction(T value, NBL_REF_ARG(SharedAccessor) accessor) |
| 21 | +// TODO: implement the helper macros below with Boost.Preprocessor (Boost PP) at some point |
| 22 | +//#define NBL_ALIAS_CALL_OPERATOR_TO_STATIC_IMPL(OPTIONAL_TEMPLATE,RETURN_TYPE,/*tuples of argument types and names*/...) |
| 23 | +//#define NBL_ALIAS_TEMPLATED_CALL_OPERATOR_TO_IMPL(TEMPLATE,RETURN_TYPE,/*tuples of argument types and names*/...) |
| 24 | + |
| 25 | +template<class BinOp, uint16_t ItemCount> |
| 26 | +struct reduction |
24 | 27 | {
|
25 |
| - REDUCE reduce = REDUCE::create(); |
26 |
| - reduce(value, accessor); |
27 |
| - accessor.main.workgroupExecutionAndMemoryBarrier(); |
28 |
| - T retVal = Broadcast<uint, SharedAccessor>(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel); |
29 |
| - return retVal; |
30 |
| -} |
| 28 | + using type_t = typename BinOp::type_t; |
31 | 29 |
|
32 |
| -template<typename T, class Binop, class SharedAccessor> |
33 |
| -T inclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor) |
| 30 | + template<class Accessor> |
| 31 | + static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor) |
| 32 | + { |
| 33 | + impl::reduce<BinOp,ItemCount> fn; |
| 34 | + fn.template __call<Accessor>(value,accessor); |
| 35 | + accessor.workgroupExecutionAndMemoryBarrier(); |
| 36 | + return Broadcast<type_t,Accessor>(fn.lastLevelScan,accessor,fn.lastInvocationInLevel); |
| 37 | + } |
| 38 | +}; |
| 39 | + |
| 40 | +template<class BinOp, uint16_t ItemCount> |
| 41 | +struct inclusive_scan |
34 | 42 | {
|
35 |
| - SCAN(false) incl_scan = SCAN(false)::create(); |
36 |
| - T retVal = incl_scan(value, accessor); |
37 |
| - return retVal; |
38 |
| -} |
| 43 | + using type_t = typename BinOp::type_t; |
| 44 | + |
| 45 | + template<class Accessor> |
| 46 | + static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor) |
| 47 | + { |
| 48 | + impl::scan<BinOp,false,ItemCount> fn; |
| 49 | + return fn.template __call<Accessor>(value,accessor); |
| 50 | + } |
| 51 | +}; |
39 | 52 |
|
40 |
| -template<typename T, class Binop, class SharedAccessor> |
41 |
| -T exclusive_scan(T value, NBL_REF_ARG(SharedAccessor) accessor) |
| 53 | +template<class BinOp, uint16_t ItemCount> |
| 54 | +struct exclusive_scan |
42 | 55 | {
|
43 |
| - SCAN(true) excl_scan = SCAN(true)::create(); |
44 |
| - T retVal = excl_scan(value, accessor); |
45 |
| - return retVal; |
46 |
| -} |
| 56 | + using type_t = typename BinOp::type_t; |
47 | 57 |
|
48 |
| -#undef REDUCE |
49 |
| -#undef SCAN |
| 58 | + template<class Accessor> |
| 59 | + static type_t __call(NBL_CONST_REF_ARG(type_t) value, NBL_REF_ARG(Accessor) accessor) |
| 60 | + { |
| 61 | + impl::scan<BinOp,true,ItemCount> fn; |
| 62 | + return fn.template __call<Accessor>(value,accessor); |
| 63 | + } |
| 64 | +}; |
50 | 65 |
|
51 |
| -#define REDUCE Reduce<uint, subgroup::inclusive_scan<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount> |
52 |
| -#define SCAN Scan<uint, plus<uint>, subgroup::inclusive_scan<uint, plus<uint> >, SharedAccessor, impl::uballotBitfieldCount, true> |
53 | 66 | /**
|
54 |
| - * Gives us the sum (reduction) of all ballots for the workgroup. |
| 67 | + * Gives us the sum (reduction) of all ballots for the ItemCount bits of a workgroup. |
55 | 68 | *
|
56 | 69 | * Only the first few invocations are used for performing the sum
|
57 |
| - * since we only have `uballotBitfieldCount` amount of uints that we need |
| 70 | + * since there is only one uint per 32 ballot bits that we need |
58 | 71 | * to add together.
|
59 | 72 | *
|
60 | 73 | * We add them all in the shared array index after the last DWORD
|
61 | 74 | * that is used for the ballots. For example, if we have 128 workgroup size,
|
62 | 75 | * then the array index in which we accumulate the sum is `4` since
|
63 | 76 | * indexes 0..3 are used for ballots.
|
64 | 77 | */
|
65 |
| -template<class SharedAccessor> |
66 |
| -uint ballotBitCount(NBL_REF_ARG(SharedAccessor) accessor) |
| 78 | +namespace impl |
67 | 79 | {
|
68 |
| - uint participatingBitfield = 0; |
69 |
| - if(SubgroupContiguousIndex() < impl::uballotBitfieldCount) |
70 |
| - { |
71 |
| - participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex()); |
72 |
| - } |
73 |
| - accessor.ballot.workgroupExecutionAndMemoryBarrier(); |
74 |
| - REDUCE reduce = REDUCE::create(); |
75 |
| - reduce(countbits(participatingBitfield), accessor); |
76 |
| - accessor.main.workgroupExecutionAndMemoryBarrier(); |
77 |
| - return Broadcast<uint, SharedAccessor>(reduce.lastLevelScan, accessor, reduce.lastInvocationInLevel); |
| 80 | +template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor, template<class,uint16_t> class op_t> |
| 81 | +uint32_t ballotPolyCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor, NBL_REF_ARG(uint32_t) localBitfield) |
| 82 | +{ |
| 83 | + localBitfield = 0u; |
| 84 | + if (SubgroupContiguousIndex()<impl::BallotDWORDCount(ItemCount)) |
| 85 | + localBitfield = ballotAccessor.get(SubgroupContiguousIndex()); |
| 86 | + return op_t<plus<uint32_t>,impl::ballot_dword_count<ItemCount>::value>::template __call<ArithmeticAccessor>(countbits(localBitfield),arithmeticAccessor); |
| 87 | +} |
78 | 88 | }
|
79 | 89 |
|
80 |
| -template<class SharedAccessor> |
81 |
| -uint ballotScanBitCount(const bool exclusive, NBL_REF_ARG(SharedAccessor) accessor) |
| 90 | +template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor> |
| 91 | +uint16_t ballotBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor) |
82 | 92 | {
|
83 |
| - const uint _dword = impl::getDWORD(SubgroupContiguousIndex()); |
84 |
| - const uint localBitfield = accessor.ballot.get(_dword); |
85 |
| - uint globalCount; |
86 |
| - { |
87 |
| - uint participatingBitfield; |
88 |
| - if(SubgroupContiguousIndex() < impl::uballotBitfieldCount) |
89 |
| - { |
90 |
| - participatingBitfield = accessor.ballot.get(SubgroupContiguousIndex()); |
91 |
| - } |
92 |
| - // scan hierarchically, invocations with `SubgroupContiguousIndex() >= uballotBitfieldCount` will have garbage here |
93 |
| - accessor.ballot.workgroupExecutionAndMemoryBarrier(); |
94 |
| - |
95 |
| - SCAN scan = SCAN::create(); |
96 |
| - uint bitscan = scan(countbits(participatingBitfield), accessor); |
97 |
| - |
98 |
| - accessor.main.set(SubgroupContiguousIndex(), bitscan); |
99 |
| - accessor.main.workgroupExecutionAndMemoryBarrier(); |
100 |
| - |
101 |
| - // fix it (abuse the fact memory is left over) |
102 |
| - globalCount = _dword != 0u ? accessor.main.get(_dword) : 0u; |
103 |
| - accessor.main.workgroupExecutionAndMemoryBarrier(); |
104 |
| - } |
105 |
| - const uint mask = (exclusive ? 0x7fFFffFFu:0xFFffFFffu)>>(31u-(SubgroupContiguousIndex()&31u)); |
106 |
| - return globalCount + countbits(localBitfield & mask); |
| 93 | + uint32_t dummy; |
| 94 | + return uint16_t(impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,reduction>(ballotAccessor,arithmeticAccessor,dummy)); |
107 | 95 | }
|
108 | 96 |
|
109 |
| -template<class SharedAccessor> |
110 |
| -uint ballotInclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor) |
| 97 | +template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor> |
| 98 | +uint16_t ballotInclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor) |
111 | 99 | {
|
112 |
| - return ballotScanBitCount<SharedAccessor>(false, accessor); |
| 100 | + uint32_t localBitfield; |
| 101 | + uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield); |
| 102 | + // using only component [0] of the mask is on purpose, only its least significant 32 bits are of interest |
| 103 | + return uint16_t(countbits(glsl::gl_SubgroupLeMask()[0]&localBitfield)+count); |
113 | 104 | }
|
114 | 105 |
|
115 |
| -template<class SharedAccessor> |
116 |
| -uint ballotExclusiveBitCount(NBL_REF_ARG(SharedAccessor) accessor) |
| 106 | +template<uint16_t ItemCount, class BallotAccessor, class ArithmeticAccessor> |
| 107 | +uint16_t ballotExclusiveBitCount(NBL_REF_ARG(BallotAccessor) ballotAccessor, NBL_REF_ARG(ArithmeticAccessor) arithmeticAccessor) |
117 | 108 | {
|
118 |
| - return ballotScanBitCount<SharedAccessor>(true, accessor); |
| 109 | + uint32_t localBitfield; |
| 110 | + uint32_t count = impl::ballotPolyCount<ItemCount,BallotAccessor,ArithmeticAccessor,exclusive_scan>(ballotAccessor,arithmeticAccessor,localBitfield); |
| 111 | + // using only component [0] of the mask is on purpose, only its least significant 32 bits are of interest |
| 112 | + return uint16_t(countbits(glsl::gl_SubgroupLtMask()[0]&localBitfield)+count); |
119 | 113 | }
|
120 | 114 |
|
121 |
| -#undef REDUCE |
122 |
| -#undef SCAN |
123 |
| - |
124 | 115 | }
|
125 | 116 | }
|
126 | 117 | }
|
|
0 commit comments