Skip to content

Commit a5360ca

Browse files
committed
Refactor following PR review
1 parent cf2f747 commit a5360ca

File tree

3 files changed

+42
-103
lines changed

3 files changed

+42
-103
lines changed

include/nbl/builtin/hlsl/subgroup/fft.hlsl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ namespace hlsl
1111
{
1212
namespace subgroup
1313
{
14-
namespace fft
15-
{
1614

1715
// -----------------------------------------------------------------------------------------------------------------------------------------------------------------
1816
template<bool Inverse, typename Scalar, class device_capabilities=void>
@@ -42,15 +40,15 @@ struct FFT<false, Scalar, device_capabilities>
4240
hi.imag(exchanged.y);
4341
}
4442
// Get twiddle with k = subgroupInvocation mod stride, halfN = stride
45-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
43+
fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
4644
}
4745

4846
static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
4947
{
5048
const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2
5149

5250
// special first iteration
53-
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
51+
fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
5452

5553
// Decimation in Frequency
5654
for (uint32_t stride = subgroupSize >> 1; stride > 0; stride >>= 1)
@@ -67,7 +65,7 @@ struct FFT<true, Scalar, device_capabilities>
6765
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
6866
{
6967
// Get twiddle with k = subgroupInvocation mod stride, halfN = stride
70-
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
68+
fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
7169

7270
const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride);
7371
const vector <Scalar, 2> toTrade = topHalf ? vector <Scalar, 2>(lo.real(), lo.imag()) : vector <Scalar, 2>(hi.real(), hi.imag());
@@ -94,7 +92,7 @@ struct FFT<true, Scalar, device_capabilities>
9492
FFT_loop(stride, lo, hi);
9593

9694
// special last iteration
97-
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
95+
fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
9896
divides_assign< complex_t<Scalar> > divAss;
9997
divAss(lo, doubleSubgroupSize);
10098
divAss(hi, doubleSubgroupSize);
@@ -105,6 +103,5 @@ struct FFT<true, Scalar, device_capabilities>
105103
}
106104
}
107105
}
108-
}
109106

110107
#endif

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 32 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"
55
#include "nbl/builtin/hlsl/workgroup/basic.hlsl"
66
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
7-
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
87
#include "nbl/builtin/hlsl/workgroup/shuffle.hlsl"
98

109
namespace nbl
@@ -18,43 +17,41 @@ namespace fft
1817
// ---------------------------------- Utils -----------------------------------------------
1918

2019
template<typename SharedMemoryAccessor, typename Scalar>
21-
void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
20+
void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
2221
{
2322
const bool topHalf = bool(threadID & stride);
23+
// Ternary won't take structs so we use this aux variable
2424
vector <Scalar, 2> toExchange = topHalf ? vector <Scalar, 2>(lo.real(), lo.imag()) : vector <Scalar, 2>(hi.real(), hi.imag());
25-
shuffleXor<SharedMemoryAccessor, vector <Scalar, 2> >::__call(toExchange, stride, sharedmemAdaptor);
25+
complex_t<Scalar> toExchangeComplex = {toExchange.x, toExchange.y};
26+
shuffleXor<SharedMemoryAccessor, complex_t<Scalar> >::__call(toExchangeComplex, stride, sharedmemAccessor);
2627
if (topHalf)
27-
{
28-
lo.real(toExchange.x);
29-
lo.imag(toExchange.y);
30-
}
28+
lo = toExchangeComplex;
3129
else
32-
{
33-
hi.real(toExchange.x);
34-
hi.imag(toExchange.y);
35-
}
30+
hi = toExchangeComplex;
3631
}
3732

33+
} //namespace fft
34+
3835
// ----------------------------------- End Utils -----------------------------------------------
3936

4037
template<uint16_t ElementsPerInvocation, bool Inverse, typename Scalar, class device_capabilities=void>
4138
struct FFT;
4239

4340
// For the FFT methods below, we assume:
4441
// - Accessor is a global memory accessor to an array fitting 2 * _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, used to get inputs / set outputs of the FFT,
45-
// that is, one "lo" and one "hi" complex numbers per thread, essentially 4 Scalars per thread. The data layout is assumed to be a whole array of real parts
46-
// followed by a whole array of imaginary parts. So it would be something like
47-
// [x_0, x_1, ..., x_{2 * _NBL_HLSL_WORKGROUP_SIZE_}, y_0, y_1, ..., y_{2 * _NBL_HLSL_WORKGROUP_SIZE_}]
48-
// - SharedMemoryAccessor accesses a shared memory array that can fit _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, so 2 * _NBL_HLSL_WORKGROUP_SIZE_ Scalars
42+
// that is, one "lo" and one "hi" complex number per thread, essentially 4 Scalars per thread.
43+
// There are no assumptions on the data layout: we just require the accessor to provide get and set methods for complex_t<Scalar>.
44+
// - SharedMemoryAccessor accesses a shared memory array that can fit _NBL_HLSL_WORKGROUP_SIZE_ elements of type complex_t<Scalar>, with get and set
45+
// methods for complex_t<Scalar>. It benefits from coalesced accesses
4946

5047
// 2 items per invocation forward specialization
5148
template<typename Scalar, class device_capabilities>
5249
struct FFT<2,false, Scalar, device_capabilities>
5350
{
5451
template<typename SharedMemoryAccessor>
55-
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
52+
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
5653
{
57-
exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);
54+
fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
5855

5956
// Get twiddle with k = threadID mod stride, halfN = stride
6057
hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
@@ -64,25 +61,14 @@ struct FFT<2,false, Scalar, device_capabilities>
6461
template<typename Accessor, typename SharedMemoryAccessor>
6562
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
6663
{
67-
// Set up the MemAdaptors
68-
MemoryAdaptor<Accessor, 1> memAdaptor;
69-
memAdaptor.accessor = accessor;
70-
MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
71-
sharedmemAdaptor.accessor = sharedmemAccessor;
72-
7364
// Compute the indices only once
7465
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
7566
const uint32_t loIx = threadID;
7667
const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_;
7768

7869
// Read lo, hi values from global memory
79-
vector <Scalar, 2> loVec;
80-
vector <Scalar, 2> hiVec;
81-
// TODO: if we get rid of the Memory Adaptor on the accessor and require complex getters and setters, then no `2*`
82-
memAdaptor.get(2 * loIx , loVec);
83-
memAdaptor.get(2 * hiIx, hiVec);
84-
complex_t<Scalar> lo = {loVec.x, loVec.y};
85-
complex_t<Scalar> hi = {hiVec.x, hiVec.y};
70+
complex_t<Scalar> lo = accessor.get(loIx);
71+
complex_t<Scalar> hi = accessor.get(hiIx);
8672

8773
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
8874
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
@@ -93,27 +79,20 @@ struct FFT<2,false, Scalar, device_capabilities>
9379
// Run bigger steps until Subgroup-sized
9480
for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
9581
{
96-
FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
97-
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
82+
FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
83+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
9884
}
9985

10086
// special last workgroup-shuffle
101-
exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
87+
fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
10288
}
10389

10490
// Subgroup-sized FFT
105-
subgroup::fft::FFT<false, Scalar, device_capabilities>::__call(lo, hi);
91+
subgroup::FFT<false, Scalar, device_capabilities>::__call(lo, hi);
10692

10793
// Put values back in global mem
108-
loVec = vector <Scalar, 2>(lo.real(), lo.imag());
109-
hiVec = vector <Scalar, 2>(hi.real(), hi.imag());
110-
111-
memAdaptor.set(2 * loIx, loVec);
112-
memAdaptor.set(2 * hiIx, hiVec);
113-
114-
// Update state for accessors
115-
accessor = memAdaptor.accessor;
116-
sharedmemAccessor = sharedmemAdaptor.accessor;
94+
accessor.set(loIx, lo);
95+
accessor.set(hiIx, hi);
11796
}
11897
};
11998

@@ -124,53 +103,42 @@ template<typename Scalar, class device_capabilities>
124103
struct FFT<2,true, Scalar, device_capabilities>
125104
{
126105
template<typename SharedMemoryAccessor>
127-
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
106+
static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
128107
{
129108
// Get twiddle with k = threadID mod stride, halfN = stride
130109
hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
131110

132-
exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);
111+
fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAccessor);
133112
}
134113

135114

136115
template<typename Accessor, typename SharedMemoryAccessor>
137116
static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
138117
{
139-
// Set up the MemAdaptors
140-
MemoryAdaptor<Accessor, 1> memAdaptor;
141-
memAdaptor.accessor = accessor;
142-
MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
143-
sharedmemAdaptor.accessor = sharedmemAccessor;
144-
145118
// Compute the indices only once
146119
const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
147120
const uint32_t loIx = (glsl::gl_SubgroupID()<<(glsl::gl_SubgroupSizeLog2()+1))+glsl::gl_SubgroupInvocationID();
148121
const uint32_t hiIx = loIx+glsl::gl_SubgroupSize();
149122

150123
// Read lo, hi values from global memory
151-
vector <Scalar, 2> loVec;
152-
vector <Scalar, 2> hiVec;
153-
memAdaptor.get(2 * loIx , loVec);
154-
memAdaptor.get(2 * hiIx, hiVec);
155-
complex_t<Scalar> lo = {loVec.x, loVec.y};
156-
complex_t<Scalar> hi = {hiVec.x, hiVec.y};
124+
complex_t<Scalar> lo = accessor.get(loIx);
125+
complex_t<Scalar> hi = accessor.get(hiIx);
157126

158127
// Run a subgroup-sized FFT, then continue with bigger steps
159-
subgroup::fft::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
128+
subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
160129

161130
// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
162-
163131
if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
164132
{
165133
// special first workgroup-shuffle
166-
exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
134+
fft::exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAccessor);
167135

168136
// The bigger steps
169137
for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
170138
{
171139
// Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
172-
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
173-
FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
140+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
141+
FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAccessor);
174142
}
175143

176144
// special last iteration
@@ -181,14 +149,8 @@ struct FFT<2,true, Scalar, device_capabilities>
181149
}
182150

183151
// Put values back in global mem
184-
loVec = vector <Scalar, 2>(lo.real(), lo.imag());
185-
hiVec = vector <Scalar, 2>(hi.real(), hi.imag());
186-
memAdaptor.set(2 * loIx, loVec);
187-
memAdaptor.set(2 * hiIx, hiVec);
188-
189-
// Update state for accessors
190-
accessor = memAdaptor.accessor;
191-
sharedmemAccessor = sharedmemAdaptor.accessor;
152+
accessor.set(loIx, lo);
153+
accessor.set(hiIx, hi);
192154
}
193155
};
194156

@@ -221,6 +183,5 @@ struct FFT
221183
}
222184
}
223185
}
224-
}
225186

226187
#endif

include/nbl/builtin/hlsl/workgroup/shuffle.hlsl

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@
44
#include "nbl/builtin/hlsl/memory_accessor.hlsl"
55

66
// TODO: Add other shuffles
7-
// TODO: Consider adding an enable_if or static assert that 1 <= N <= 4 and that Scalar is a proper scalar type
8-
// TODO: Consider adding version that doesn't take a precomputed threadID and instead calls workgroup::SubgroupContiguousIndex
97

10-
// Unlike subgroups we pass a precomputed threadID so we don't go around recomputing it every time
118
// We assume the accessor in the adaptor is clean and unaliased when calling this function, but we don't enforce this after the shuffle
129

1310
namespace nbl
@@ -17,41 +14,25 @@ namespace hlsl
1714
namespace workgroup
1815
{
1916

20-
2117
template<typename SharedMemoryAccessor, typename T>
2218
struct shuffleXor
2319
{
24-
static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
20+
static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
2521
{
26-
sharedmemAdaptor.set(threadID, value);
22+
sharedmemAccessor.set(threadID, value);
2723

2824
// Wait until all writes are done before reading
29-
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
25+
sharedmemAccessor.workgroupExecutionAndMemoryBarrier();
3026

31-
sharedmemAdaptor.get(threadID ^ mask, value);
27+
value = sharedmemAccessor.get(threadID ^ mask);
3228
}
3329

34-
static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
30+
static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
3531
{
36-
__call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
32+
__call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAccessor);
3733
}
3834
};
3935

40-
/*
41-
42-
template<typename SharedMemoryAccessor, typename T>
43-
void shuffleXor(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
44-
{
45-
sharedmemAdaptor.set(threadID, value);
46-
47-
// Wait until all writes are done before reading
48-
sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
49-
50-
sharedmemAdaptor.get(threadID ^ mask, value);
51-
}
52-
53-
*/
54-
5536
}
5637
}
5738
}

0 commit comments

Comments
 (0)