Commit e03b142

Add subgroup shuffle functionality and use in alternate implementation for scan ops
1 parent 3782a34 commit e03b142

File tree

4 files changed: +205 -134 lines

include/nbl/builtin/hlsl/subgroup/arithmetic_portability.hlsl

Lines changed: 17 additions & 22 deletions
@@ -5,8 +5,6 @@
 #include <nbl/builtin/hlsl/subgroup/basic_portability.hlsl>
 #endif
 
-static const uint4 WHOLE_WAVE = ~0; // REVIEW: Confirm this is proper placement and definition point
-
 namespace nbl
 {
 namespace hlsl
@@ -28,17 +26,29 @@ struct inclusive_scan;
 }
 #endif
 
+namespace portability
+{
+
+// PORTABILITY BINOP DECLARATIONS
+template<class Binop, class ScratchAccessor>
+struct reduction;
+template<class Binop, class ScratchAccessor>
+struct inclusive_scan;
+template<class Binop, class ScratchAccessor>
+struct exclusive_scan;
+
+}
+
 template<class Binop>
 struct reduction
 {
     template<class ScratchAccessor, typename T>
     T operator()(const T x)
-    {
+    { // REVIEW: Should these extension headers have the GL name?
 #ifdef NBL_GL_KHR_shader_subgroup_arithmetic
         return native::reduction<Binop>()(x);
 #else
-        portability::reduction<Binop,ScratchAccessor> impl;
-        return impl(x);
+        return portability::reduction<Binop,ScratchAccessor>::create()(x);
 #endif
     }
 };
@@ -52,8 +62,7 @@ struct exclusive_scan
 #ifdef NBL_GL_KHR_shader_subgroup_arithmetic
         return native::exclusive_scan<Binop>()(x);
 #else
-        portability::exclusive_scan<Binop,ScratchAccessor> impl;
-        return impl(x);
+        portability::exclusive_scan<Binop,ScratchAccessor>::create()(x);
 #endif
     }
 };
@@ -67,25 +76,11 @@ struct inclusive_scan
 #ifdef NBL_GL_KHR_shader_subgroup_arithmetic
         return native::inclusive_scan<Binop>()(x);
 #else
-        portability::inclusive_scan<Binop,ScratchAccessor> impl;
-        return impl(x);
+        portability::inclusive_scan<Binop,ScratchAccessor>::create()(x);
 #endif
     }
 };
 
-namespace portability
-{
-
-// PORTABILITY BINOP DECLARATIONS
-template<class Binop, class ScratchAccessor>
-struct reduction;
-template<class Binop, class ScratchAccessor>
-struct inclusive_scan;
-template<class Binop, class ScratchAccessor>
-struct exclusive_scan;
-
-}
-
 }
 }
 }
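
The structs above dispatch at compile time: with NBL_GL_KHR_shader_subgroup_arithmetic defined the native intrinsics are used, otherwise the shared-memory portability fallback. A hypothetical caller-side sketch, assuming HLSL 2021 templates and that the dispatch structs live in nbl::hlsl::subgroup (as the three closing namespace braces suggest); the groupshared buffer and accessor are invented for illustration, only binops::bitwise_add and the operator()<ScratchAccessor,T> shape come from this commit:

// Hypothetical usage sketch; the groupshared buffer and accessor are assumptions.
groupshared uint scratch[_NBL_HLSL_WORKGROUP_SIZE_<<1u];

struct MyScratchAccessor
{
    uint get(const uint ix) { return scratch[ix]; }
    void set(const uint ix, const uint value) { scratch[ix] = value; }
};

uint subgroupPrefixSum(const uint x)
{
    // resolves to native::inclusive_scan under NBL_GL_KHR_shader_subgroup_arithmetic,
    // otherwise to the portability implementation backed by `scratch`
    nbl::hlsl::subgroup::inclusive_scan<binops::bitwise_add> scan;
    return scan.operator()<MyScratchAccessor,uint>(x);
}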

include/nbl/builtin/hlsl/subgroup/arithmetic_portability_impl.hlsl

Lines changed: 117 additions & 96 deletions
@@ -1,7 +1,9 @@
 #ifndef _NBL_BUILTIN_HLSL_SUBGROUP_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
 #define _NBL_BUILTIN_HLSL_SUBGROUP_ARITHMETIC_PORTABILITY_IMPL_INCLUDED_
 
-uint localInvocationIndex : SV_GroupIndex; // REVIEW: Discuss proper placement of SV_* values. They are not allowed to be defined inside a function scope, only as arguments of global variables in the shader.
+#define WHOLE_WAVE ~0
+
+uint gl_LocalInvocationIndex : SV_GroupIndex; // REVIEW: Discuss proper placement of SV_* values. They are not allowed to be defined inside a function scope, only as arguments of global variables in the shader.
 
 namespace nbl
 {
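
The REVIEW comment concerns where SV_* values may be declared: HLSL delivers system values through entry-point parameters (or globals, as attempted above), never function locals. For comparison, the conventional form, as a hypothetical entry point not taken from this commit:

// Conventional delivery of SV_GroupIndex: an entry-point parameter holding the
// flat invocation index within the workgroup (what gl_LocalInvocationIndex aliases).
[numthreads(_NBL_HLSL_WORKGROUP_SIZE_, 1, 1)]
void CSMain(const uint localInvocationIndex : SV_GroupIndex)
{
    // kernel body would consume localInvocationIndex here
}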
@@ -117,7 +119,7 @@ struct exclusive_scan<binops::bitwise_add>
     template<typename T>
     T operator()(const T x)
     {
-        return WaveMultiPrefixSum(x, WHOLE_WAVE);
+        return WavePrefixSum(x);
     }
 };
 template<>
@@ -126,7 +128,7 @@ struct inclusive_scan<binops::bitwise_add>
     template<typename T>
     T operator()(const T x)
     {
-        return WaveMultiPrefixSum(x, WHOLE_WAVE) + x;
+        return WavePrefixSum(x) + x;
     }
 };
 
@@ -146,7 +148,7 @@ struct exclusive_scan<binops::bitwise_mul>
     template<typename T>
     T operator()(const T x)
     {
-        return WaveMultiPrefixProduct(x, WHOLE_WAVE);
+        return WavePrefixProduct(x);
     }
 };
 template<>
@@ -155,7 +157,7 @@ struct inclusive_scan<binops::bitwise_mul>
     template<typename T>
     T operator()(const T x)
     {
-        return WaveMultiPrefixProduct(x, WHOLE_WAVE) * x;
+        return WavePrefixProduct(x) * x;
     }
 };
 
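These specializations rely on WavePrefixSum/WavePrefixProduct being exclusive scans, which is why the inclusive variants fold the invocation's own value back in; the dropped WaveMultiPrefix* intrinsics compute the same thing per lane-mask partition, so passing the whole-wave mask made them equivalent but needlessly general. A worked sketch of the add case (the lane values are an invented example):

// For lane values {3, 1, 4}:
//   WavePrefixSum(x)     -> {0, 3, 4}  exclusive: sums strictly-lower lanes
//   WavePrefixSum(x) + x -> {3, 4, 8}  inclusive: folds in this lane's value
uint inclusiveSum(const uint x)
{
    return WavePrefixSum(x) + x;
}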
@@ -264,81 +266,89 @@ struct ScratchAccessorAdaptor {
 
 struct scan_base
 {
-    // even if you have a `const uint nbl::hlsl::subgroup::Size` it wont work I think, so `#define` needed
-    static const uint SubgroupSize = nbl::hlsl::subgroup::subgroupSize();
-    static const uint HalfSubgroupSize = SubgroupSize>>1u; // REVIEW: Is this ok?
-    static const uint LoMask = SubgroupSize-1u;
-    static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1; // REVIEW: Where should this be defined?
-    static const uint pseudoSubgroupInvocation = localInvocationIndex&LoMask; // Also used in substructs, thus static const
+    // even if you have a `const uint nbl::hlsl::subgroup::Size` it wont work I think, so `#define` needed
+    static const uint SubgroupSize = nbl::hlsl::subgroup::Size();
+    static const uint HalfSubgroupSize = SubgroupSize>>1u; // REVIEW: Is this ok?
+    static const uint LoMask = SubgroupSize-1u;
+    static const uint LastWorkgroupInvocation = _NBL_HLSL_WORKGROUP_SIZE_-1; // REVIEW: Where should this be defined?
+    static const uint pseudoSubgroupInvocation = gl_LocalInvocationIndex&LoMask; // Also used in substructs, thus static const
 
-    static inclusive_scan<Binop,ScratchAccessor> create()
-    {
-        const uint pseudoSubgroupElectedInvocation = localInvocationIndex&(~LoMask);
-
-        inclusive_scan<Binop,ScratchAccessor> retval;
-
-        const uint subgroupMemoryBegin = pseudoSubgroupElectedInvocation<<1u;
-        retval.lastLoadOffset = subgroupMemoryBegin+pseudoSubgroupInvocation;
-
-        const uint paddingMemoryEnd = subgroupMemoryBegin+HalfSubgroupSize;
-        retval.scanStoreOffset = paddingMemoryEnd+pseudoSubgroupInvocation;
-
-        uint reductionResultOffset = paddingMemoryEnd;
-        if ((LastWorkgroupInvocation>>nbl::hlsl::subgroup::subgroupSizeLog2())!=nbl::hlsl::subgroup::subgroupInvocationID())
-            retval.reductionResultOffset += LastWorkgroupInvocation&LoMask;
-        else
-            retval.reductionResultOffset += LoMask;
-
-        retval.paddingMemoryEnd = reductionResultOffset;
-
-        return retval;
+    static inclusive_scan<Binop,ScratchAccessor> create()
+    {
+        const uint pseudoSubgroupElectedInvocation = gl_LocalInvocationIndex&(~LoMask);
+
+        inclusive_scan<Binop,ScratchAccessor> retval;
+
+        const uint subgroupMemoryBegin = pseudoSubgroupElectedInvocation<<1u;
+        retval.lastLoadOffset = subgroupMemoryBegin+pseudoSubgroupInvocation;
+        retval.paddingMemoryEnd = subgroupMemoryBegin+HalfSubgroupSize;
+        retval.scanStoreOffset = retval.paddingMemoryEnd+pseudoSubgroupInvocation;
+
+        return retval;
     }
 
     // protected:
-    uint paddingMemoryEnd;
-    uint scanStoreOffset;
-    uint lastLoadOffset;
+    uint paddingMemoryEnd;
+    uint scanStoreOffset;
+    uint lastLoadOffset;
 };
 
 template<class Binop, class ScratchAccessor>
 struct inclusive_scan : scan_base
 {
     static inclusive_scan<Binop,ScratchAccessor> create()
     {
-        return scan_base<Binop,ScratchAccessor>::create(); // REVIEW: Is this correct?
+        return scan_base<Binop,ScratchAccessor>::create(); // REVIEW: Is this correct?
     }
 
     template<typename T, bool initializeScratch>
     T operator()(T value)
     {
-        ScratchAccessor scratchAccessor;
-        Binop op;
+        ScratchAccessor scratchAccessor;
+        Binop op;
 
-        if (initializeScratch)
-        {
-            nbl::hlsl::subgroupBarrier();
-            nbl::hlsl::subgroupMemoryBarrierShared();
-            scratchAccessor.set(scanStoreOffset ,value);
-            if (scan_base::pseudoSubgroupInvocation<scan_base::HalfSubgroupSize)
-                scratchAccessor.set(lastLoadOffset,op::identity());
-        }
-        nbl::hlsl::subgroupBarrier();
-        nbl::hlsl::subgroupMemoryBarrierShared();
-        // Stone-Kogge adder
-        // (devsh): it seems that lanes below <HalfSubgroupSize/step are doing useless work,
-        // but they're SIMD and adding an `if`/conditional execution is more expensive
-        value = op(value,scratchAccessor.get(scanStoreOffset-1u));
-        [[unroll]]
-        for (uint stp=2u; stp<=scan_base::HalfSubgroupSize; stp<<=1u)
-        {
-            scratchAccessor.set(scanStoreOffset,value);
-            nbl::hlsl::subgroupBarrier();
-            nbl::hlsl::subgroupMemoryBarrierShared();
-            value = op(value,scratchAccessor.get(scanStoreOffset-stp));
-            nbl::hlsl::subgroupBarrier();
-            nbl::hlsl::subgroupMemoryBarrierShared();
-        }
-        return value;
+        if (initializeScratch)
+        {
+            nbl::hlsl::subgroup::Barrier();
+            nbl::hlsl::subgroup::MemoryBarrierShared();
+
+            // each invocation initializes its respective slot with its value
+            scratchAccessor.set(scanStoreOffset ,value);
+
+            // additionally, the first half invocations initialize the padding slots
+            // with identity values
+            if (scan_base::pseudoSubgroupInvocation<scan_base::HalfSubgroupSize)
+                scratchAccessor.set(lastLoadOffset,op::identity());
+        }
+        nbl::hlsl::subgroup::Barrier();
+        nbl::hlsl::subgroup::MemoryBarrierShared();
+        // Stone-Kogge adder
+        // (devsh): it seems that lanes below <HalfSubgroupSize/step are doing useless work,
+        // but they're SIMD and adding an `if`/conditional execution is more expensive
+#ifdef NBL_GL_KHR_shader_subgroup_shuffle
+        if(scan_base::pseudoSubgroupInvocation>=1u)
+            // the first invocation (index 0) in the subgroup doesn't have anything in its left
+            value = op(value, ShuffleUp(value, 1u));
+#else
+        value = op(value,scratchAccessor.get(scanStoreOffset-1u));
+#endif
+        [[unroll]]
+        for (uint step=2u; step<=scan_base::HalfSubgroupSize; step<<=1u)
+        {
+#ifdef NBL_GL_KHR_shader_subgroup_shuffle // REVIEW: maybe use it by default?
+            // there is no scratch and padding entries in this case so we have to guard the shuffles to not go out of bounds
+            if(scan_base::pseudoSubgroupInvocation>=step)
+                value = op(value, ShuffleUp(value, step));
+#else
+            scratchAccessor.set(scanStoreOffset,value);
+            nbl::hlsl::subgroup::Barrier();
+            nbl::hlsl::subgroup::MemoryBarrierShared();
+            value = op(value,scratchAccessor.get(scanStoreOffset-step));
+            nbl::hlsl::subgroup::Barrier();
+            nbl::hlsl::subgroup::MemoryBarrierShared();
+#endif
+        }
+        return value;
     }
 
     template<typename T>
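
The shuffle branch is the classic Kogge-Stone scan (the comment's "Stone-Kogge" transposes the name): each doubling step folds in the partial result from `step` lanes below, so log2(SubgroupSize) steps cover the whole subgroup. A condensed restatement of just that branch, as a sketch: it assumes ShuffleUp(v, delta) reads `v` from the lane `delta` positions below (as the call sites above imply) and that HLSL 2021 templates are available.

// Kogge-Stone inclusive scan, shuffle variant only (illustrative, not the commit's code).
template<class Binop, typename T>
T subgroupInclusiveScanShuffle(T value)
{
    Binop op;
    const uint lane = nbl::hlsl::subgroup::InvocationID();
    [[unroll]]
    for (uint step=1u; step<nbl::hlsl::subgroup::Size(); step<<=1u)
    {
        const T below = ShuffleUp(value, step);
        if (lane>=step) // lanes below `step` have no partner that far to the left
            value = op(value, below);
    }
    return value;
}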
@@ -361,29 +371,32 @@ struct exclusive_scan
     template<typename T, bool initializeScratch>
     T operator()(T value)
     {
-        value = impl.operator()<T,initializeScratch>(value);
-
-        // store value to smem so we can shuffle it
-        scratchAccessor.set(impl.scanStoreOffset,value);
-        nbl::hlsl::subgroupBarrier();
-        nbl::hlsl::subgroupMemoryBarrierShared();
-        // get previous item
-        value = scratchAccessor.get(impl.scanStoreOffset-1u);
-        nbl::hlsl::subgroupBarrier();
-        nbl::hlsl::subgroupMemoryBarrierShared();
-
-        // return it
-        return value;
+        value = impl.operator()<T,initializeScratch>(value);
+
+        // store value to smem so we can shuffle it
+#ifdef NBL_GL_KHR_shader_subgroup_shuffle // REVIEW: Should we check this or just use shuffle by default?
+        value = ShuffleUp(value, 1);
+#else
+        scratchAccessor.set(impl.scanStoreOffset,value);
+        nbl::hlsl::subgroup::Barrier();
+        nbl::hlsl::subgroup::MemoryBarrierShared();
+        // get previous item
+        value = scratchAccessor.get(impl.scanStoreOffset-1u);
+        nbl::hlsl::subgroup::Barrier();
+        nbl::hlsl::subgroup::MemoryBarrierShared();
+#endif
+        // return it
+        return value;
     }
 
     template<typename T>
     T operator()(const T value)
     {
-        return operator()<T,true>(value);
+        return operator()<T,true>(value);
     }
-
+
     // protected:
-    inclusive_scan<Binop,ScratchAccessor> impl;
+    inclusive_scan<Binop,ScratchAccessor> impl;
 };
 
 template<class Binop, class ScratchAccessor>
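
An exclusive scan is just the inclusive result shifted down by one lane, which is all the ShuffleUp(value, 1) above does. Note that lane 0 reads from a nonexistent source lane there; a stricter sketch would substitute the binop's identity, assuming the binops expose an identity() like the scratch-initialization path uses:

// Exclusive scan from an inclusive one (illustrative sketch, not the commit's code).
template<class Binop, typename T>
T subgroupExclusiveScanShuffle(T value)
{
    value = subgroupInclusiveScanShuffle<Binop,T>(value); // sketch above
    const T shifted = ShuffleUp(value, 1u); // predecessor's inclusive result
    // lane 0 has no predecessor, so hand it the identity element instead
    return nbl::hlsl::subgroup::InvocationID()!=0u ? shifted : Binop::identity();
}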
@@ -399,23 +412,29 @@ struct reduction
     template<typename T, bool initializeScratch>
     T operator()(T value)
     {
-        value = impl.operator()<T,initializeScratch>(value);
-
-        // store value to smem so we can broadcast it to everyone
-        scratchAccessor.set(impl.scanStoreOffset,value);
-        nbl::hlsl::subgroupBarrier();
-        nbl::hlsl::subgroupMemoryBarrierShared();
-        uint reductionResultOffset = impl.paddingMemoryEnd;
-        if ((scan_base::LastWorkgroupInvocation>>nbl::hlsl::subgroup::subgroupSizeLog2())!=nbl::hlsl::subgroup::ID())
-            reductionResultOffset += scan_base::LastWorkgroupInvocation & scan_base::LoMask;
-        else
-            reductionResultOffset += scan_base::LoMask;
-        value = scratchAccessor.get(reductionResultOffset);
-        nbl::hlsl::subgroupBarrier();
-        nbl::hlsl::subgroupMemoryBarrierShared();
-
-        // return it
-        return value;
+        value = impl.operator()<T,initializeScratch>(value);
+
+        // in case of multiple subgroups inside the WG
+        if ((scan_base::LastWorkgroupInvocation>>nbl::hlsl::subgroup::SizeLog2())!=nbl::hlsl::subgroup::InvocationID())
+            reductionResultOffset += scan_base::LastWorkgroupInvocation & scan_base::LoMask;
+        else // in case of single subgroup in WG
+            reductionResultOffset += scan_base::LoMask;
+
+#ifdef NBL_GL_KHR_shader_subgroup_shuffle
+        Shuffle(value, reductionResultOffset);
+#else
+        // store value to smem so we can broadcast it to everyone
+        scratchAccessor.set(impl.scanStoreOffset,value);
+        nbl::hlsl::subgroup::Barrier();
+        nbl::hlsl::subgroup::MemoryBarrierShared();
+        uint reductionResultOffset = impl.paddingMemoryEnd;
+
+        value = scratchAccessor.get(reductionResultOffset);
+        nbl::hlsl::subgroup::Barrier();
+        nbl::hlsl::subgroup::MemoryBarrierShared();
+#endif
+        // return it
+        return value;
     }
 
     template<typename T>
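
The reduction reuses the inclusive scan and then broadcasts the result held by the last invocation of the (pseudo-)subgroup, which is what the offset computation selects. Restated as a sketch, under the assumptions that Shuffle(v, lane) returns `v` from an absolute lane index and that the subgroup is fully active:

// Reduction as "broadcast the last lane's inclusive scan result"
// (illustrative sketch, not the commit's code).
template<class Binop, typename T>
T subgroupReductionShuffle(T value)
{
    value = subgroupInclusiveScanShuffle<Binop,T>(value); // sketch above
    // the top lane's inclusive result is the reduction over the whole subgroup
    return Shuffle(value, nbl::hlsl::subgroup::Size()-1u);
}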
@@ -433,4 +452,6 @@ struct reduction
 }
 }
 
+#undef WHOLE_WAVE
+
 #endif
