Commit e98ef60

2-element per invocation workgroup FFT working fine
1 parent 188e63e commit e98ef60

File tree

4 files changed (+84, -83 lines)

3rdparty/dxc/dxc

Submodule dxc updated 257 files

include/nbl/builtin/hlsl/subgroup/fft.hlsl

Lines changed: 7 additions & 4 deletions
@@ -11,6 +11,8 @@ namespace hlsl
 {
 namespace subgroup
 {
+namespace fft
+{

 // -----------------------------------------------------------------------------------------------------------------------------------------------------------------
 template<bool Inverse, typename Scalar, class device_capabilities=void>
@@ -40,15 +42,15 @@ struct FFT<false, Scalar, device_capabilities>
 hi.imag(exchanged.y);
 }
 // Get twiddle with k = subgroupInvocation mod stride, halfN = stride
-fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
+hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
 }

 static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
 {
 const uint32_t subgroupSize = glsl::gl_SubgroupSize(); //This is N/2

 // special first iteration
-fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
+hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);

 // Decimation in Frequency
 for (uint32_t stride = subgroupSize >> 1; stride > 0; stride >>= 1)
@@ -65,7 +67,7 @@ struct FFT<true, Scalar, device_capabilities>
 static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
 {
 // Get twiddle with k = subgroupInvocation mod stride, halfN = stride
-fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);
+hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID() & (stride - 1), stride), lo, hi);

 const bool topHalf = bool(glsl::gl_SubgroupInvocationID() & stride);
 const vector <Scalar, 2> toTrade = topHalf ? vector <Scalar, 2>(lo.real(), lo.imag()) : vector <Scalar, 2>(hi.real(), hi.imag());
@@ -92,7 +94,7 @@ struct FFT<true, Scalar, device_capabilities>
 FFT_loop(stride, lo, hi);

 // special last iteration
-fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
+hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(glsl::gl_SubgroupInvocationID(), subgroupSize), lo, hi);
 divides_assign< complex_t<Scalar> > divAss;
 divAss(lo, doubleSubgroupSize);
 divAss(hi, doubleSubgroupSize);
@@ -103,5 +105,6 @@ struct FFT<true, Scalar, device_capabilities>
 }
 }
 }
+}

 #endif
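The forward and inverse structs now live in the new subgroup::fft namespace, while the shared radix-2 butterflies and twiddle factors stay in hlsl::fft, which is why the bodies above spell out hlsl::fft::DIF, hlsl::fft::DIT and hlsl::fft::twiddle in full. A minimal call-site sketch under that layout follows; the scalar type, wrapper function and include style are illustrative assumptions, not part of this commit:

// Hypothetical call site (illustration only): forward subgroup FFT on the lo/hi pair each invocation owns
#include "nbl/builtin/hlsl/subgroup/fft.hlsl"

void forwardSubgroupFFT(inout nbl::hlsl::complex_t<float> lo, inout nbl::hlsl::complex_t<float> hi)
{
    // Inverse = false selects the DIF specialization; device_capabilities defaults to void
    nbl::hlsl::subgroup::fft::FFT<false, float>::__call(lo, hi);
}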

include/nbl/builtin/hlsl/workgroup/fft.hlsl

Lines changed: 52 additions & 76 deletions
@@ -11,28 +11,18 @@ namespace nbl
 {
 namespace hlsl
 {
-
-namespace glsl
-{
-
-// Define this method from glsl_compat/core.hlsl
-uint32_t3 gl_WorkGroupSize() {
-return uint32_t3(_NBL_HLSL_WORKGROUP_SIZE_, 1, 1);
-}
-
-} //namespace glsl
-
 namespace workgroup
 {
-
+namespace fft
+{
 // ---------------------------------- Utils -----------------------------------------------

 template<typename SharedMemoryAccessor, typename Scalar>
 void exchangeValues(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
 {
 const bool topHalf = bool(threadID & stride);
 vector <Scalar, 2> toExchange = topHalf ? vector <Scalar, 2>(lo.real(), lo.imag()) : vector <Scalar, 2>(hi.real(), hi.imag());
-shuffleXor<SharedMemoryAccessor, Scalar, 2>(toExchange, stride, threadID, sharedmemAdaptor);
+shuffleXor<SharedMemoryAccessor, vector <Scalar, 2> >::__call(toExchange, stride, sharedmemAdaptor);
 if (topHalf)
 {
 lo.real(toExchange.x);
@@ -67,60 +57,59 @@ struct FFT<2,false, Scalar, device_capabilities>
 exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);

 // Get twiddle with k = threadID mod stride, halfN = stride
-fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
+hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
 }


 template<typename Accessor, typename SharedMemoryAccessor>
 static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
 {
 // Set up the MemAdaptors
-MemoryAdaptor<Accessor, _NBL_HLSL_WORKGROUP_SIZE_ << 1> memAdaptor;
+MemoryAdaptor<Accessor, 1> memAdaptor;
 memAdaptor.accessor = accessor;
 MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
 sharedmemAdaptor.accessor = sharedmemAccessor;

-// Compute the SubgroupContiguousIndex only once
+// Compute the indices only once
 const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
+const uint32_t loIx = threadID;
+const uint32_t hiIx = loIx + _NBL_HLSL_WORKGROUP_SIZE_;

 // Read lo, hi values from global memory
 vector <Scalar, 2> loVec;
 vector <Scalar, 2> hiVec;
-memAdaptor.get(threadID, loVec);
-memAdaptor.get(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+// TODO: if we get rid of the Memory Adaptor on the accessor and require complex getters and setters, then no `2*`
+memAdaptor.get(2 * loIx, loVec);
+memAdaptor.get(2 * hiIx, hiVec);
 complex_t<Scalar> lo = {loVec.x, loVec.y};
 complex_t<Scalar> hi = {hiVec.x, hiVec.y};

-// special first iteration - only if workgroupsize > subgroupsize
+// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
 if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-fft::DIF<Scalar>::radix2(fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+{
+// special first iteration
+hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);

-// Run bigger steps until Subgroup-sized
-for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
-{
-// If at least one loop was executed, we must wait for all threads to get their values before we write to shared mem again
-if ( !(stride & (_NBL_HLSL_WORKGROUP_SIZE_ >> 1)) )
+// Run bigger steps until Subgroup-sized
+for (uint32_t stride = _NBL_HLSL_WORKGROUP_SIZE_ >> 1; stride > glsl::gl_SubgroupSize(); stride >>= 1)
+{
+FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
 sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
-}
+}

-// special last workgroup-shuffle - only if workgroupsize > subgroupsize
-if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-{
-// Wait for all threads to be done with reads in the last loop before writing to shared mem
-sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
-}
+// special last workgroup-shuffle
+exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+}

 // Subgroup-sized FFT
-subgroup::FFT<false, Scalar, device_capabilities>::__call(lo, hi);
+subgroup::fft::FFT<false, Scalar, device_capabilities>::__call(lo, hi);

 // Put values back in global mem
 loVec = vector <Scalar, 2>(lo.real(), lo.imag());
 hiVec = vector <Scalar, 2>(hi.real(), hi.imag());

-memAdaptor.set(threadID, loVec);
-memAdaptor.set(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+memAdaptor.set(2 * loIx, loVec);
+memAdaptor.set(2 * hiIx, hiVec);

 // Update state for accessors
 accessor = memAdaptor.accessor;
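Spelling out the forward path's new addressing: each invocation owns the complex elements loIx = threadID and hiIx = threadID + _NBL_HLSL_WORKGROUP_SIZE_, and because the adaptor is now the scalar-granular MemoryAdaptor<Accessor, 1> while every complex element presumably occupies a (real, imag) scalar pair, loads and stores go through offsets 2 * loIx and 2 * hiIx, exactly as the TODO above notes. A tiny sketch of that arithmetic; the interleaved layout and a user-defined _NBL_HLSL_WORKGROUP_SIZE_ are assumptions:

// Illustration only: scalar offsets the forward path feeds to memAdaptor.get()/set(),
// assuming element k is stored as the scalar pair { [2k] = real, [2k + 1] = imag }
void forwardScalarOffsets(uint32_t threadID, out uint32_t loScalar, out uint32_t hiScalar)
{
    const uint32_t loIx = threadID;                              // first element owned by this invocation
    const uint32_t hiIx = threadID + _NBL_HLSL_WORKGROUP_SIZE_;  // second element, half the array away
    loScalar = 2 * loIx;
    hiScalar = 2 * hiIx;
}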
@@ -138,7 +127,7 @@ struct FFT<2,true, Scalar, device_capabilities>
 static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
 {
 // Get twiddle with k = threadID mod stride, halfN = stride
-fft::DIF<Scalar>::radix2(fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
+hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);

 exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, stride, sharedmemAdaptor);
 }
@@ -148,53 +137,54 @@ struct FFT<2,true, Scalar, device_capabilities>
 static void __call(NBL_REF_ARG(Accessor) accessor, NBL_REF_ARG(SharedMemoryAccessor) sharedmemAccessor)
 {
 // Set up the MemAdaptors
-MemoryAdaptor<Accessor, _NBL_HLSL_WORKGROUP_SIZE_ << 1> memAdaptor;
+MemoryAdaptor<Accessor, 1> memAdaptor;
 memAdaptor.accessor = accessor;
 MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor;
 sharedmemAdaptor.accessor = sharedmemAccessor;

-// Compute the SubgroupContiguousIndex only once
+// Compute the indices only once
 const uint32_t threadID = uint32_t(SubgroupContiguousIndex());
+const uint32_t loIx = (glsl::gl_SubgroupID()<<(glsl::gl_SubgroupSizeLog2()+1))+glsl::gl_SubgroupInvocationID();
+const uint32_t hiIx = loIx+glsl::gl_SubgroupSize();

 // Read lo, hi values from global memory
 vector <Scalar, 2> loVec;
 vector <Scalar, 2> hiVec;
-memAdaptor.get(threadID, loVec);
-memAdaptor.get(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+memAdaptor.get(2 * loIx, loVec);
+memAdaptor.get(2 * hiIx, hiVec);
 complex_t<Scalar> lo = {loVec.x, loVec.y};
 complex_t<Scalar> hi = {hiVec.x, hiVec.y};

 // Run a subgroup-sized FFT, then continue with bigger steps
-subgroup::FFT<true, Scalar, device_capabilities>::__call(lo, hi);
+subgroup::fft::FFT<true, Scalar, device_capabilities>::__call(lo, hi);

-// special first workgroup-shuffle - only if workgroupsize > subgroupsize
+// If for some reason you're running a small FFT, skip all the bigger-than-subgroup steps
+
 if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
 {
+// special first workgroup-shuffle
 exchangeValues<SharedMemoryAccessor, Scalar>(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
-}
-
-// The bigger steps
-for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
-{
-// If we enter this for loop, then the special first workgroup shuffle went through, so wait on that
-sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
-FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
-}
+
+// The bigger steps
+for (uint32_t stride = glsl::gl_SubgroupSize() << 1; stride < _NBL_HLSL_WORKGROUP_SIZE_; stride <<= 1)
+{
+// Order of waiting for shared mem writes is also reversed here, since the shuffle came earlier
+sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+FFT_loop<SharedMemoryAccessor>(stride, lo, hi, threadID, sharedmemAdaptor);
+}

-// special last iteration - only if workgroupsize > subgroupsize
-if (_NBL_HLSL_WORKGROUP_SIZE_ > glsl::gl_SubgroupSize())
-{
-fft::DIT<Scalar>::radix2(fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
+// special last iteration
+hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID, _NBL_HLSL_WORKGROUP_SIZE_), lo, hi);
 divides_assign< complex_t<Scalar> > divAss;
 divAss(lo, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
-}
+divAss(hi, _NBL_HLSL_WORKGROUP_SIZE_ / glsl::gl_SubgroupSize());
+}

 // Put values back in global mem
 loVec = vector <Scalar, 2>(lo.real(), lo.imag());
 hiVec = vector <Scalar, 2>(hi.real(), hi.imag());
-memAdaptor.set(threadID, loVec);
-memAdaptor.set(threadID + _NBL_HLSL_WORKGROUP_SIZE_, hiVec);
+memAdaptor.set(2 * loIx, loVec);
+memAdaptor.set(2 * hiIx, hiVec);

 // Update state for accessors
 accessor = memAdaptor.accessor;
@@ -203,21 +193,6 @@ struct FFT<2,true, Scalar, device_capabilities>
 };


-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 // ---------------------------- Below pending --------------------------------------------------

 /*
@@ -246,5 +221,6 @@ struct FFT
 }
 }
 }
+}

 #endif
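Note that the inverse path derives its indices from subgroup built-ins rather than from threadID: loIx = (gl_SubgroupID() << (gl_SubgroupSizeLog2() + 1)) + gl_SubgroupInvocationID() and hiIx = loIx + gl_SubgroupSize(), so every subgroup addresses a contiguous block of 2 * SubgroupSize complex elements and each invocation's pair sits a subgroup-size apart inside it. A worked sketch of that arithmetic, with concrete values assumed purely for illustration:

// Illustration only: with gl_SubgroupSize() == 32 (so gl_SubgroupSizeLog2() == 5),
// gl_SubgroupID() == 1 and gl_SubgroupInvocationID() == 5:
//   loIx = (1 << (5 + 1)) + 5 = 69
//   hiIx = 69 + 32 = 101
// and the scalar offsets handed to the adaptor are 2 * loIx and 2 * hiIx, as in the forward path
uint32_t2 inverseElementIndices()
{
    const uint32_t loIx = (glsl::gl_SubgroupID() << (glsl::gl_SubgroupSizeLog2() + 1)) + glsl::gl_SubgroupInvocationID();
    const uint32_t hiIx = loIx + glsl::gl_SubgroupSize();
    return uint32_t2(loIx, hiIx);
}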

include/nbl/builtin/hlsl/workgroup/shuffle.hlsl

Lines changed: 24 additions & 2 deletions
@@ -18,8 +18,29 @@ namespace workgroup
 {


-template<typename SharedMemoryAccessor, typename Scalar, uint32_t N = 1>
-void shuffleXor(NBL_REF_ARG(vector <Scalar, N>) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
+template<typename SharedMemoryAccessor, typename T>
+struct shuffleXor
+{
+static void __call(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
+{
+sharedmemAdaptor.set(threadID, value);
+
+// Wait until all writes are done before reading
+sharedmemAdaptor.workgroupExecutionAndMemoryBarrier();
+
+sharedmemAdaptor.get(threadID ^ mask, value);
+}
+
+static void __call(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
+{
+__call(value, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
+}
+};
+
+/*
+
+template<typename SharedMemoryAccessor, typename T>
+void shuffleXor(NBL_REF_ARG(T) value, uint32_t mask, uint32_t threadID, NBL_REF_ARG(MemoryAdaptor<SharedMemoryAccessor>) sharedmemAdaptor)
 {
 sharedmemAdaptor.set(threadID, value);

@@ -29,6 +50,7 @@ void shuffleXor(NBL_REF_ARG(vector <Scalar, N>) value, uint32_t mask, uint32_t t
 sharedmemAdaptor.get(threadID ^ mask, value);
 }

+*/

 }
 }
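The free function becomes a struct so the exchanged type T is arbitrary (the FFT now trades a whole vector<Scalar, 2>) and so the thread index can default to SubgroupContiguousIndex(), which is the overload exchangeValues() in workgroup/fft.hlsl now calls. A minimal usage sketch, assuming a shared-memory accessor already wrapped in a MemoryAdaptor as the FFT does; the function name and payload type are illustrative only:

// Illustration only: trade a (real, imag) pair with the invocation whose workgroup index differs by `mask`
template<typename SharedMemoryAccessor>
void tradeWithPartner(inout vector<float, 2> payload, uint32_t mask, inout MemoryAdaptor<SharedMemoryAccessor> sharedmemAdaptor)
{
    // Overload that defaults the thread index to SubgroupContiguousIndex(), as exchangeValues() does
    shuffleXor<SharedMemoryAccessor, vector<float, 2> >::__call(payload, mask, sharedmemAdaptor);

    // An explicit index can also be passed when it was already computed:
    // shuffleXor<SharedMemoryAccessor, vector<float, 2> >::__call(payload, mask, uint32_t(SubgroupContiguousIndex()), sharedmemAdaptor);
}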
