Skip to content

Commit 89099f7

Browse files
- add gl_WorkGroupSize
- add a HLSL `assert` - clean up everything except for workgroup arithmetic - reduce reliance on NBL_WORKGROUP_SIZE define
1 parent 15e8d52 commit 89099f7

File tree

8 files changed

+76
-40
lines changed

8 files changed

+76
-40
lines changed

include/nbl/builtin/hlsl/glsl_compat/core.hlsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ T atomicCompSwap(NBL_REF_ARG(T) ptr, T comparator, T value)
6363
// TODO (Future): It's annoying we have to forward declare those, but accessing gl_NumSubgroups and other gl_* values is not yet possible due to https://github.com/microsoft/DirectXShaderCompiler/issues/4217
6464
// also https://github.com/microsoft/DirectXShaderCompiler/issues/5280
6565
uint32_t gl_LocalInvocationIndex();
66+
uint32_t3 gl_WorkGroupSize();
6667
uint32_t3 gl_GlobalInvocationID();
6768
uint32_t3 gl_WorkGroupID();
6869

include/nbl/builtin/hlsl/macros.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66

77
#ifdef __HLSL_VERSION
88
#define static_assert(...) _Static_assert(__VA_ARGS__)
9+
#define assert(expr) \
10+
{ \
11+
bool con = (expr); \
12+
do { \
13+
[branch] if (!con) \
14+
vk::RawBufferStore<uint32_t>(0xdeadbeefBADC0FFbull,0x45u,4u); \
15+
} while(!con); \
16+
}
917
#endif
1018

1119
// basics

include/nbl/builtin/hlsl/workgroup/ballot.hlsl

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,23 @@ namespace hlsl
1414
{
1515
namespace workgroup
1616
{
17+
1718
namespace impl
1819
{
19-
// Maps an invocation index to the index of the 32-bit word that holds its ballot bit.
uint16_t getDWORD(uint16_t invocation)
{
    // each word carries 32 bits, so the word index is the invocation index shifted right by 5,
    // i.e. by the log2 of sizeof(uint32_t)*8
    const uint16_t wordIndex = invocation>>5;
    // debug check: the word index must be below the number of words needed to cover `Volume()` items
    assert(wordIndex<((Volume()+31)>>5));
    return wordIndex;
}
2326

24-
// uballotBitfieldCount essentially means 'how many DWORDs are needed to store ballots in bitfields, for each invocation of the workgroup'
25-
// can't use getDWORD because we want the static const to be treated as 'constexpr'
26-
static const uint uballotBitfieldCount = (_NBL_HLSL_WORKGROUP_SIZE_+31) >> 5; // in case WGSZ is not a multiple of 32 we might miscalculate the DWORDs after the right-shift by 5 which is why we add 31
27-
27+
// essentially means 'how many DWORDs are needed to store ballots in bitfields, one bit for each of the `itemCount` items'
28+
uint16_t BallotDWORDCount(const uint16_t itemCount)
{
    // Round up, in case all items don't fit in an even number of DWORDs.
    // The shift is done inline instead of through `getDWORD`: that helper asserts its result is a
    // valid DWORD *index* (strictly below `(Volume()+31)>>5`), which a DWORD *count* computed for
    // `itemCount==Volume()` can never satisfy, so the assert would always fire.
    return (itemCount+31)>>5;
}
32+
}
33+
2934
/**
3035
* Simple ballot function.
3136
*
@@ -34,42 +39,44 @@ static const uint uballotBitfieldCount = (_NBL_HLSL_WORKGROUP_SIZE_+31) >> 5; //
3439
* then the Uint will be ...00100000
3540
* This way we can encode 32 invocations into a single Uint.
3641
*
37-
* All Uints are kept in contiguous accessor memory in a shared array.
38-
* The size of that array is based on the WORKGROUP SIZE. In this case we use uballotBitfieldCount.
42+
* All Uints are kept in contiguous accessor memory in an array (shared is best).
43+
* The size of that array is based on the ItemCount.
3944
*
4045
* For each group of 32 invocations, a DWORD is assigned to the array (i.e. a 32-bit value, in this case Uint).
4146
* For example, for a workgroup size 128, 4 DWORDs are needed.
4247
* For each invocation index, we can find its respective DWORD index in the accessor array
4348
* by calling the getDWORD function.
4449
*/
45-
template<class Accessor>
void ballot(const bool value, NBL_REF_ARG(Accessor) accessor)
{
    const uint32_t index = SubgroupContiguousIndex();

    // the first `impl::BallotDWORDCount(Volume())` invocations each zero one DWORD of the bitfield
    const bool initialize = index<impl::BallotDWORDCount(Volume());
    if (initialize)
        accessor.set(index,0u);

    // separates the zeroing above from the atomic ORs below
    accessor.workgroupExecutionAndMemoryBarrier();

    if (value)
    {
        // output argument required by `atomicOr`; its value is not used here
        uint32_t dummy;
        // set this invocation's bit (index modulo 32) in the DWORD it maps to
        accessor.atomicOr(impl::getDWORD(index),1u<<(index&31u),dummy);
    }
}
5865

59-
template<class Accessor>
bool ballotBitExtract(const uint32_t index, NBL_REF_ARG(Accessor) accessor)
{
    // single-bit mask selecting `index` within its DWORD (index modulo 32)
    const uint32_t bitMask = 1u<<(index&31u);
    // read the DWORD that `index` maps to and test the selected bit
    return bool(accessor.get(impl::getDWORD(index))&bitMask);
}
6471

6572
/**
6673
* Once we have assigned ballots in the shared array, we can
6774
* extract any invocation's ballot value using this function.
6875
*/
69-
template<class Accessor>
bool inverseBallot(NBL_REF_ARG(Accessor) accessor)
{
    // look up the ballot bit belonging to the calling invocation
    const uint32_t invocationIndex = SubgroupContiguousIndex();
    return ballotBitExtract<Accessor>(invocationIndex,accessor);
}
7481

7582
}

include/nbl/builtin/hlsl/workgroup/basic.hlsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ namespace workgroup
1616

1717
static const uint32_t MaxWorkgroupSizeLog2 = 11;
1818
static const uint32_t MaxWorkgroupSize = 0x1u<<MaxWorkgroupSizeLog2;
19+
20+
// Product of the three components of `glsl::gl_WorkGroupSize()`.
uint32_t Volume()
{
    const uint32_t3 wgSize = glsl::gl_WorkGroupSize();
    return wgSize.x*wgSize.y*wgSize.z;
}
1925

2026
uint32_t SubgroupContiguousIndex()
2127
{

include/nbl/builtin/hlsl/workgroup/broadcast.hlsl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,21 @@ namespace workgroup
2121
* We save the value at index 0 of the accessor
2222
* and then all invocations access that index.
2323
*/
24-
template<typename T, class Accessor>
T Broadcast(NBL_CONST_REF_ARG(T) val, NBL_REF_ARG(Accessor) accessor, const uint32_t id)
{
    // only the invocation whose contiguous index equals `id` writes its value
    const bool isSource = SubgroupContiguousIndex()==id;
    if (isSource)
    {
        accessor.set(0,val);
    }

    // separates the write above from the reads below
    accessor.workgroupExecutionAndMemoryBarrier();

    // every invocation reads back the stored value
    return accessor.get(0);
}
3534

36-
template<typename T, class Accessor>
T BroadcastFirst(NBL_CONST_REF_ARG(T) val, NBL_REF_ARG(Accessor) accessor)
{
    // `Broadcast` takes (val, accessor, id): forward the accessor and use invocation 0 as the source
    return Broadcast<T,Accessor>(val,accessor,0u);
}
4640

4741
}

include/nbl/builtin/hlsl/workgroup/scratch_sz.hlsl renamed to include/nbl/builtin/hlsl/workgroup/scratch_size.hlsl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
#include "nbl/builtin/hlsl/type_traits.hlsl"
99

10-
// REVIEW-519: Review this whole header and content (whether it should be here or somewhere else)
10+
1111
namespace nbl
1212
{
1313
namespace hlsl
1414
{
1515
namespace workgroup
1616
{
17+
1718
namespace impl
1819
{
1920
template<uint32_t N, uint32_t K>
@@ -28,6 +29,20 @@ struct trunc_geom_series<N,K,W,false> : integral_constant<uint32_t,ceil_div<N,W>
2829
template<uint32_t N, uint32_t K, uint32_t W>
2930
struct trunc_geom_series<N,K,W,true> : integral_constant<uint32_t,0> {};
3031
}
32+
33+
template<uint16_t ContiguousItemCount>
34+
struct scratch_size_ballot
35+
{
36+
NBL_CONSTEXPR_STATIC_INLINE uint16_t value = (ContiguousItemCount+31)>>5;
37+
};
38+
39+
// you're only writing one element, so a single scratch slot is enough
NBL_CONSTEXPR_STATIC uint16_t scratch_size_broadcast = 1u;
41+
42+
// if you know better you can use the actual subgroup size
43+
template<uint16_t ContiguousItemCount, uint16_t SubgroupSize=subgroup::MinSubgroupSize>
44+
struct scratch_size_arithmetic : impl::trunc_geom_series<ContiguousItemCount,SubgroupSize> {};
45+
3146
}
3247
}
3348
}

src/nbl/builtin/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,12 @@ LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/basic.hlsl")
271271
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability.hlsl")
272272
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/subgroup/arithmetic_portability_impl.hlsl")
273273
#workgroup
274+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/arithmetic.hlsl")
274275
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/basic.hlsl")
276+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/ballot.hlsl")
277+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/broadcast.hlsl")
278+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/scratch_size.hlsl")
279+
LIST_BUILTIN_RESOURCE(NBL_RESOURCES_TO_EMBED "hlsl/workgroup/shared_Scan.hlsl")
275280

276281

277282
macro(NBL_ADD_BUILTIN_RESOURCES _TARGET_) # internal & Nabla only, must be added with the macro to properly propagate scope

0 commit comments

Comments
 (0)