Commit 8861228

fft algorithm improvement (use half the threads and compute the full butterfly)
1 parent 44af96e commit 8861228
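
The change described in the commit message maps each (virtual) workgroup thread to one full radix-2 DIT butterfly per stage, so an N-point transform only needs N/2 threads and every thread produces both the even and the odd output. A minimal single-threaded C sketch of the same indexing, for reference only; the helper names, main() and the round-trip check are illustrative and not part of the shader:

    /* Reference model of "half the threads, full butterfly": N/2 workers per
       stage, each computing both outputs of its butterfly.  The index and
       twiddle formulas mirror the ones added in this commit. */
    #include <complex.h>
    #include <stdio.h>

    static unsigned bit_reverse(unsigned v, unsigned bits)
    {
        unsigned r = 0u;
        for (unsigned i = 0u; i < bits; ++i)
            r |= ((v >> i) & 1u) << (bits - 1u - i);
        return r;
    }

    /* same formula as nbl_glsl_ext_FFT_getEvenIndex */
    static unsigned even_index(unsigned tid, unsigned stage, unsigned N)
    {
        return ((tid & (N - (1u << stage))) << 1u) | (tid & ((1u << stage) - 1u));
    }

    /* same formula as the re-added nbl_glsl_ext_FFT_calculateTwiddlePower */
    static unsigned twiddle_power(unsigned tid, unsigned stage, unsigned logTwoN)
    {
        return (tid & ((1u << stage) - 1u)) << (logTwoN - 1u - stage);
    }

    static void fft(double complex *a, unsigned N, unsigned logTwoN, int inverse)
    {
        const double PI = 3.14159265358979323846;
        /* bit-reversal permutation (the shader does this while loading) */
        for (unsigned i = 0u; i < N; ++i) {
            unsigned j = bit_reverse(i, logTwoN);
            if (i < j) { double complex tmp = a[i]; a[i] = a[j]; a[j] = tmp; }
        }
        for (unsigned stage = 0u; stage < logTwoN; ++stage) {
            for (unsigned tid = 0u; tid < N / 2u; ++tid) {   /* one worker = one butterfly */
                unsigned e = even_index(tid, stage, N);
                unsigned o = e + (1u << stage);
                unsigned k = twiddle_power(tid, stage, logTwoN);
                double complex w = cexp((inverse ? I : -I) * 2.0 * PI * (double)k / (double)N);
                double complex t = w * a[o];
                a[o] = a[e] - t;                             /* odd output  */
                a[e] = a[e] + t;                             /* even output */
            }
        }
        if (inverse)                                         /* matches the "/ dataLength" in the diff */
            for (unsigned i = 0u; i < N; ++i) a[i] /= (double)N;
    }

    int main(void)
    {
        enum { LOG2N = 3, N = 1 << LOG2N };
        double complex a[N], orig[N];
        for (unsigned i = 0u; i < N; ++i) orig[i] = a[i] = (double)i + 0.5 * I * (double)i;
        fft(a, N, LOG2N, 0);                                 /* forward */
        fft(a, N, LOG2N, 1);                                 /* inverse */
        double err = 0.0;
        for (unsigned i = 0u; i < N; ++i) err += cabs(a[i] - orig[i]);
        printf("round-trip error: %g\n", err);               /* ~1e-15 if the indexing is right */
        return 0;
    }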

File tree (1 file changed)
  • include/nbl/builtin/glsl/ext/FFT

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 191 additions & 24 deletions
@@ -25,8 +25,7 @@
 #error "_NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_ should be defined."
 #endif
 
-// TODO: Investigate why +1 solves all glitches
-#define _NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_ _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_ + 1
+#define _NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_ _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_
 
 #ifdef _NBL_GLSL_SCRATCH_SHARED_DEFINED_
 #if NBL_GLSL_LESS(_NBL_GLSL_SCRATCH_SHARED_SIZE_DEFINED_,_NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_)
@@ -103,24 +102,6 @@ uint nbl_glsl_ext_FFT_Parameters_t_getPaddingType() {
     return (params.dimension.w) & 0x000000ff;
 }
 
-uint nbl_glsl_ext_FFT_calculateTwiddlePower(in uint threadId, in uint iteration, in uint logTwoN)
-{
-    const uint shiftSuffix = logTwoN - 1u - iteration; // can we assert that iteration<logTwoN always?? yes
-    const uint suffixMask = (2u << iteration) - 1u;
-    return (threadId & suffixMask) << shiftSuffix;
-}
-
-nbl_glsl_complex nbl_glsl_ext_FFT_twiddle(in uint threadId, in uint iteration, in uint logTwoN)
-{
-    uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN);
-    return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
-}
-
-nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse(in uint threadId, in uint iteration, in uint logTwoN)
-{
-    return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN));
-}
-
 uint nbl_glsl_ext_FFT_getChannel()
 {
     uint direction = nbl_glsl_ext_FFT_Parameters_t_getDirection();
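
Note that the twiddle helpers deleted here come back further down in this diff with one change: the suffix mask shrinks from (2u << iteration) - 1u to (1u << iteration) - 1u, so one fewer bit of the thread id reaches the exponent; whenever that extra bit is set, the old exponent picks up an additional N/2, i.e. a factor of -1 on the twiddle. A standalone C sketch that prints the two exponents side by side (the helper name, mask parameter and loop bounds are illustrative, not shader code):

    #include <stdio.h>

    /* k = (threadId & mask) << (logTwoN - 1 - iteration), with the mask built
       either the removed way (base 2u) or the re-added way (base 1u). */
    static unsigned twiddle_power(unsigned threadId, unsigned iteration, unsigned logTwoN, unsigned maskBase)
    {
        const unsigned shiftSuffix = logTwoN - 1u - iteration;
        const unsigned suffixMask = (maskBase << iteration) - 1u;
        return (threadId & suffixMask) << shiftSuffix;
    }

    int main(void)
    {
        enum { LOG2N = 3 };                       /* N = 8, 4 butterflies per stage */
        for (unsigned it = 0u; it < LOG2N; ++it)
            for (unsigned tid = 0u; tid < 4u; ++tid)
                printf("iteration=%u tid=%u  removed k=%u  re-added k=%u\n", it, tid,
                       twiddle_power(tid, it, LOG2N, 2u),
                       twiddle_power(tid, it, LOG2N, 1u));
        return 0;
    }
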
@@ -150,7 +131,190 @@ uint nbl_glsl_ext_FFT_getDimLength(uvec3 dimension)
     return dimension[direction];
 }
 
+uint nbl_glsl_ext_FFT_calculateTwiddlePower(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    const uint shiftSuffix = logTwoN - 1u - iteration;
+    const uint suffixMask = (1u << iteration) - 1u;
+    return (threadId & suffixMask) << shiftSuffix;
+}
+
+nbl_glsl_complex nbl_glsl_ext_FFT_twiddle(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN);
+    return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
+}
+
+nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN));
+}
+
+uint nbl_glsl_ext_FFT_getEvenIndex(in uint threadId, in uint iteration, in uint N) {
+    return ((threadId & (N - (1u << iteration))) << 1u) | (threadId & ((1u << iteration) - 1u));
+}
+
 void nbl_glsl_ext_FFT(bool is_inverse)
+{
+    nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
+    // Virtual Threads Calculation
+    uint dataLength = nbl_glsl_ext_FFT_getDimLength(nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions());
+    uint num_virtual_threads = ((dataLength >> 1)-1u)/(_NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_)+1u;
+    uint thread_offset = gl_LocalInvocationIndex;
+
+    uint channel = nbl_glsl_ext_FFT_getChannel();
+
+    // Pass 0: Bit Reversal
+    uint leadingZeroes = nbl_glsl_clz(dataLength) + 1u;
+    uint logTwo = 32u - leadingZeroes;
+
+    nbl_glsl_complex even_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD]; // should be half the prev version
+    nbl_glsl_complex odd_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD];
+
+    // Load Initial Values into Local Mem (bit reversed indices)
+    for(uint t = 0u; t < num_virtual_threads; t++)
+    {
+        // TODO: read coords and shuffle with shared memory (scattered access of fast memory)
+        uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+        uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, 0, dataLength); // same as tid * 2
+
+        uvec3 coords_e = nbl_glsl_ext_FFT_getCoordinates(even_index);
+        uvec3 bitReversedCoords_e = nbl_glsl_ext_FFT_getBitReversedCoordinates(coords_e, leadingZeroes);
+        even_values[t] = nbl_glsl_ext_FFT_getPaddedData(bitReversedCoords_e, channel);
+
+        uvec3 coords_o = nbl_glsl_ext_FFT_getCoordinates(even_index + 1);
+        uvec3 bitReversedCoords_o = nbl_glsl_ext_FFT_getBitReversedCoordinates(coords_o, leadingZeroes);
+        odd_values[t] = nbl_glsl_ext_FFT_getPaddedData(bitReversedCoords_o, channel);
+    }
+
+    // For loop for each stage of the FFT (each virtual thread computes 1 butterfly)
+    for(uint i = 0u; i < logTwo; ++i)
+    {
+        // Computation of each virtual thread
+        for(uint t = 0u; t < num_virtual_threads; t++)
+        {
+            uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+            nbl_glsl_complex even_value = even_values[t];
+            nbl_glsl_complex odd_value = odd_values[t];
+
+            nbl_glsl_complex twiddle = (!is_inverse)
+                ? nbl_glsl_ext_FFT_twiddle(tid, i, logTwo)
+                : nbl_glsl_ext_FFT_twiddleInverse(tid, i, logTwo);
+
+            nbl_glsl_complex cmplx_mul = nbl_glsl_complex_mul(twiddle, odd_value);
+
+            even_values[t] = even_value + cmplx_mul;
+            odd_values[t] = even_value - cmplx_mul;
+        }
+
+        // Exchange Even and Odd Values with Other Threads (or maybe this thread)
+        if(i < logTwo - 1)
+        {
+            // Get Even/Odd Values X for virtual threads
+            for(uint t = 0u; t < num_virtual_threads; t++)
+            {
+                uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                uint stage = i;
+                uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                uint odd_index = even_index + (1u << stage);
+
+                _NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index] = floatBitsToUint(even_values[t].x);
+                _NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index] = floatBitsToUint(odd_values[t].x);
+            }
+
+            barrier();
+            memoryBarrierShared();
+
+            for(uint t = 0u; t < num_virtual_threads; t++)
+            {
+                uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                uint stage = i + 1u;
+                uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                uint odd_index = even_index + (1u << stage);
+
+                even_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index]);
+                odd_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index]);
+            }
+
+            barrier();
+            memoryBarrierShared();
+
+            // Get Even/Odd Values Y for virtual threads
+            for(uint t = 0u; t < num_virtual_threads; t++)
+            {
+                uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                uint stage = i;
+                uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                uint odd_index = even_index + (1u << stage);
+
+                _NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index] = floatBitsToUint(even_values[t].y);
+                _NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index] = floatBitsToUint(odd_values[t].y);
+            }
+
+            barrier();
+            memoryBarrierShared();
+
+            for(uint t = 0u; t < num_virtual_threads; t++)
+            {
+                uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                uint stage = i + 1u;
+                uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                uint odd_index = even_index + (1u << stage);
+
+                even_values[t].y = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index]);
+                odd_values[t].y = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index]);
+            }
+        }
+    }
+
+    for(uint t = 0u; t < num_virtual_threads; t++)
+    {
+        uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+        uint stage = logTwo - 1;
+        uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength); // same as tid
+        uint odd_index = even_index + (1u << stage);
+
+        uvec3 coords_e = nbl_glsl_ext_FFT_getCoordinates(even_index);
+        uvec3 coords_o = nbl_glsl_ext_FFT_getCoordinates(odd_index);
+
+        nbl_glsl_complex complex_value_e = (!is_inverse)
+            ? even_values[t]
+            : even_values[t] / dataLength;
+
+        nbl_glsl_complex complex_value_o = (!is_inverse)
+            ? odd_values[t]
+            : odd_values[t] / dataLength;
+
+        nbl_glsl_ext_FFT_setData(coords_e, channel, complex_value_e);
+        nbl_glsl_ext_FFT_setData(coords_o, channel, complex_value_o);
+    }
+}
+
+// REMOVE THESE 3 commits later :D
+uint nbl_glsl_ext_FFT_calculateTwiddlePower_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    const uint shiftSuffix = logTwoN - 1u - iteration;
+    const uint suffixMask = (2u << iteration) - 1u;
+    return (threadId & suffixMask) << shiftSuffix;
+}
+
+nbl_glsl_complex nbl_glsl_ext_FFT_twiddle_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN);
+    return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
+}
+
+nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+{
+    return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN));
+}
+
+void nbl_glsl_ext_FFT_OLD(bool is_inverse)
 {
     nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
     // Virtual Threads Calculation
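
A quick way to convince yourself that the even/odd pairing added above touches every element exactly once per stage is to enumerate the pairs on the CPU. Standalone sketch, for illustration only; the helper mirrors nbl_glsl_ext_FFT_getEvenIndex but nothing else here comes from the shader:

    /* For N = 16, checks that at every stage the N/2 pairs
       (even_index, even_index + (1 << stage)) cover 0..N-1 exactly once. */
    #include <assert.h>
    #include <stdio.h>

    static unsigned even_index(unsigned tid, unsigned stage, unsigned N)
    {
        return ((tid & (N - (1u << stage))) << 1u) | (tid & ((1u << stage) - 1u));
    }

    int main(void)
    {
        enum { LOG2N = 4, N = 1 << LOG2N };
        for (unsigned stage = 0u; stage < LOG2N; ++stage) {
            unsigned seen[N] = {0u};
            for (unsigned tid = 0u; tid < N / 2u; ++tid) {
                unsigned e = even_index(tid, stage, N);
                unsigned o = e + (1u << stage);
                assert(e < N && o < N);
                seen[e] += 1u;
                seen[o] += 1u;
            }
            for (unsigned i = 0u; i < N; ++i)
                assert(seen[i] == 1u);        /* each element owned by exactly one butterfly */
            printf("stage %u: every index covered exactly once\n", stage);
        }
        return 0;
    }
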
@@ -198,6 +362,9 @@ void nbl_glsl_ext_FFT(bool is_inverse)
         shuffled_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[tid ^ mask]);
     }
 
+    barrier();
+    memoryBarrierShared();
+
     // Get Shuffled Values Y for virtual threads
     for(uint t = 0u; t < num_virtual_threads; t++)
     {
@@ -218,9 +385,9 @@ void nbl_glsl_ext_FFT(bool is_inverse)
         uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
         nbl_glsl_complex shuffled_value = shuffled_values[t];
 
-        nbl_glsl_complex twiddle = (is_inverse)
-            ? nbl_glsl_ext_FFT_twiddle(tid, i, logTwo)
-            : nbl_glsl_ext_FFT_twiddleInverse(tid, i, logTwo);
+        nbl_glsl_complex twiddle = (!is_inverse)
+            ? nbl_glsl_ext_FFT_twiddle_OLD(tid, i, logTwo)
+            : nbl_glsl_ext_FFT_twiddleInverse_OLD(tid, i, logTwo);
 
         nbl_glsl_complex this_value = current_values[t];
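
The flipped condition above now pairs the forward transform with the forward twiddle and the inverse transform with its complex conjugate. A tiny standalone illustration of the two factors (k = 1, N = 8); the variable names are made up and nothing here is shader code:

    #include <complex.h>
    #include <stdio.h>

    int main(void)
    {
        const double PI = 3.14159265358979323846;
        const double k = 1.0, N = 8.0;
        double complex forward = cexp(-I * 2.0 * PI * k / N); /* exp(-2*pi*i*k/N), used when !is_inverse */
        double complex inverse = conj(forward);               /* conjugate, used for the inverse pass */
        printf("forward: %+.4f %+.4fi\n", creal(forward), cimag(forward));
        printf("inverse: %+.4f %+.4fi\n", creal(inverse), cimag(inverse));
        return 0;
    }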

@@ -236,7 +403,7 @@ void nbl_glsl_ext_FFT(bool is_inverse)
     {
         uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
         uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
-        nbl_glsl_complex complex_value = (is_inverse)
+        nbl_glsl_complex complex_value = (!is_inverse)
             ? current_values[t]
             : current_values[t] / dataLength;
 