#error "_NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_ should be defined."
#endif

- // TODO: Investigate why +1 solves all glitches
- #define _NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_ _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_ + 1
+ #define _NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_ _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_

#ifdef _NBL_GLSL_SCRATCH_SHARED_DEFINED_
#if NBL_GLSL_LESS(_NBL_GLSL_SCRATCH_SHARED_SIZE_DEFINED_,_NBL_GLSL_EXT_FFT_SHARED_SIZE_NEEDED_)
@@ -103,24 +102,6 @@ uint nbl_glsl_ext_FFT_Parameters_t_getPaddingType() {
    return (params.dimension.w) & 0x000000ff;
}

- uint nbl_glsl_ext_FFT_calculateTwiddlePower(in uint threadId, in uint iteration, in uint logTwoN)
- {
-     const uint shiftSuffix = logTwoN - 1u - iteration; // can we assert that iteration<logTwoN always?? yes
-     const uint suffixMask = (2u << iteration) - 1u;
-     return (threadId & suffixMask) << shiftSuffix;
- }
-
- nbl_glsl_complex nbl_glsl_ext_FFT_twiddle(in uint threadId, in uint iteration, in uint logTwoN)
- {
-     uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN);
-     return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
- }
-
- nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse(in uint threadId, in uint iteration, in uint logTwoN)
- {
-     return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN));
- }
-
uint nbl_glsl_ext_FFT_getChannel()
{
    uint direction = nbl_glsl_ext_FFT_Parameters_t_getDirection();
@@ -150,7 +131,190 @@ uint nbl_glsl_ext_FFT_getDimLength(uvec3 dimension)
    return dimension[direction];
}

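+ // Twiddle exponent k for the radix-2 DIT butterfly: the low `iteration` bits of
+ // threadId locate the butterfly inside its current sub-transform, shifted up so
+ // that k / 2^logTwoN spans the unit circle.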
+ uint nbl_glsl_ext_FFT_calculateTwiddlePower(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     const uint shiftSuffix = logTwoN - 1u - iteration;
+     const uint suffixMask = (1u << iteration) - 1u;
+     return (threadId & suffixMask) << shiftSuffix;
+ }
+
+ nbl_glsl_complex nbl_glsl_ext_FFT_twiddle(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN);
+     return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
+ }
+
+ nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN));
+ }
+
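+ // Inserts a zero bit at position `iteration` of threadId: the result is the index
+ // of the even (lower) element of this thread's butterfly pair, and the odd element
+ // sits exactly (1 << iteration) above it.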
+ uint nbl_glsl_ext_FFT_getEvenIndex(in uint threadId, in uint iteration, in uint N)
+ {
+     return ((threadId & (N - (1u << iteration))) << 1u) | (threadId & ((1u << iteration) - 1u));
+ }
+
void nbl_glsl_ext_FFT(bool is_inverse)
+ {
+     nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
+     // Virtual Threads Calculation
+     uint dataLength = nbl_glsl_ext_FFT_getDimLength(nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions());
+     uint num_virtual_threads = ((dataLength >> 1) - 1u) / (_NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_) + 1u;
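+     // i.e. ceil((dataLength/2) / workgroupSize): one butterfly (an even/odd pair) per virtual thread per stage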
+     uint thread_offset = gl_LocalInvocationIndex;
+
+     uint channel = nbl_glsl_ext_FFT_getChannel();
+
+     // Pass 0: Bit Reversal
+     uint leadingZeroes = nbl_glsl_clz(dataLength) + 1u;
+     uint logTwo = 32u - leadingZeroes;
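+     // for the power-of-two padded length, logTwo = log2(dataLength); leadingZeroes (= 32 - logTwo)
+     // is presumably the down-shift that confines the 32-bit reversal in getBitReversedCoordinates to logTwo bits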
+
+     nbl_glsl_complex even_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD]; // needs only half the size of the previous version's array
+     nbl_glsl_complex odd_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD];
+
+     // Load Initial Values into Local Mem (bit reversed indices)
+     for (uint t = 0u; t < num_virtual_threads; t++)
+     {
+         // TODO: read coords and shuffle with shared memory (scattered access of fast memory)
+         uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+         uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, 0u, dataLength); // same as tid * 2
+
+         uvec3 coords_e = nbl_glsl_ext_FFT_getCoordinates(even_index);
+         uvec3 bitReversedCoords_e = nbl_glsl_ext_FFT_getBitReversedCoordinates(coords_e, leadingZeroes);
+         even_values[t] = nbl_glsl_ext_FFT_getPaddedData(bitReversedCoords_e, channel);
+
+         uvec3 coords_o = nbl_glsl_ext_FFT_getCoordinates(even_index + 1u);
+         uvec3 bitReversedCoords_o = nbl_glsl_ext_FFT_getBitReversedCoordinates(coords_o, leadingZeroes);
+         odd_values[t] = nbl_glsl_ext_FFT_getPaddedData(bitReversedCoords_o, channel);
+     }
+
+     // Loop over each stage of the FFT (each virtual thread computes 1 butterfly)
+     for (uint i = 0u; i < logTwo; ++i)
+     {
+         // Computation of each virtual thread
+         for (uint t = 0u; t < num_virtual_threads; t++)
+         {
+             uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+             nbl_glsl_complex even_value = even_values[t];
+             nbl_glsl_complex odd_value = odd_values[t];
+
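+             // radix-2 DIT butterfly: even' = even + w*odd, odd' = even - w*odd (w conjugated for the inverse transform)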
+             nbl_glsl_complex twiddle = (!is_inverse)
+                 ? nbl_glsl_ext_FFT_twiddle(tid, i, logTwo)
+                 : nbl_glsl_ext_FFT_twiddleInverse(tid, i, logTwo);
+
+             nbl_glsl_complex cmplx_mul = nbl_glsl_complex_mul(twiddle, odd_value);
+
+             even_values[t] = even_value + cmplx_mul;
+             odd_values[t] = even_value - cmplx_mul;
+         }
+
+         // Exchange Even and Odd Values with Other Threads (or maybe this thread)
+         if (i < logTwo - 1u)
+         {
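+             // repack through the shared scratch, one float component at a time (the scratch holds a single uint per element):
+             // write at this stage's indices, then read back at the next stage's indices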
+             // Get Even/Odd Values X for virtual threads
+             for (uint t = 0u; t < num_virtual_threads; t++)
+             {
+                 uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                 uint stage = i;
+                 uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                 uint odd_index = even_index + (1u << stage);
+
+                 _NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index] = floatBitsToUint(even_values[t].x);
+                 _NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index] = floatBitsToUint(odd_values[t].x);
+             }
+
+             barrier();
+             memoryBarrierShared();
+
+             for (uint t = 0u; t < num_virtual_threads; t++)
+             {
+                 uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                 uint stage = i + 1u;
+                 uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                 uint odd_index = even_index + (1u << stage);
+
+                 even_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index]);
+                 odd_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index]);
+             }
+
+             barrier();
+             memoryBarrierShared();
+
+             // Get Even/Odd Values Y for virtual threads
+             for (uint t = 0u; t < num_virtual_threads; t++)
+             {
+                 uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                 uint stage = i;
+                 uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                 uint odd_index = even_index + (1u << stage);
+
+                 _NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index] = floatBitsToUint(even_values[t].y);
+                 _NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index] = floatBitsToUint(odd_values[t].y);
+             }
+
+             barrier();
+             memoryBarrierShared();
+
+             for (uint t = 0u; t < num_virtual_threads; t++)
+             {
+                 uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+                 uint stage = i + 1u;
+                 uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength);
+                 uint odd_index = even_index + (1u << stage);
+
+                 even_values[t].y = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[even_index]);
+                 odd_values[t].y = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[odd_index]);
+             }
+         }
+     }
+
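+     // write the final stage out directly from registers; the inverse transform applies its 1/N normalization here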
+     for (uint t = 0u; t < num_virtual_threads; t++)
+     {
+         uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
+
+         uint stage = logTwo - 1u;
+         uint even_index = nbl_glsl_ext_FFT_getEvenIndex(tid, stage, dataLength); // same as tid
+         uint odd_index = even_index + (1u << stage);
+
+         uvec3 coords_e = nbl_glsl_ext_FFT_getCoordinates(even_index);
+         uvec3 coords_o = nbl_glsl_ext_FFT_getCoordinates(odd_index);
+
+         nbl_glsl_complex complex_value_e = (!is_inverse)
+             ? even_values[t]
+             : even_values[t] / dataLength;
+
+         nbl_glsl_complex complex_value_o = (!is_inverse)
+             ? odd_values[t]
+             : odd_values[t] / dataLength;
+
+         nbl_glsl_ext_FFT_setData(coords_e, channel, complex_value_e);
+         nbl_glsl_ext_FFT_setData(coords_o, channel, complex_value_o);
+     }
+ }
+
+ // REMOVE THESE 3 commits later :D
+ uint nbl_glsl_ext_FFT_calculateTwiddlePower_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     const uint shiftSuffix = logTwoN - 1u - iteration;
+     const uint suffixMask = (2u << iteration) - 1u;
+     return (threadId & suffixMask) << shiftSuffix;
+ }
+
+ nbl_glsl_complex nbl_glsl_ext_FFT_twiddle_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     uint k = nbl_glsl_ext_FFT_calculateTwiddlePower_OLD(threadId, iteration, logTwoN);
+     return nbl_glsl_expImaginary(-1.0f * 2.0f * nbl_glsl_PI * float(k) / float(1 << logTwoN));
+ }
+
+ nbl_glsl_complex nbl_glsl_ext_FFT_twiddleInverse_OLD(in uint threadId, in uint iteration, in uint logTwoN)
+ {
+     return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle_OLD(threadId, iteration, logTwoN));
+ }
+
+ void nbl_glsl_ext_FFT_OLD(bool is_inverse)
{
    nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
    // Virtual Threads Calculation
@@ -198,6 +362,9 @@ void nbl_glsl_ext_FFT(bool is_inverse)
            shuffled_values[t].x = uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[tid ^ mask]);
        }

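+         // added barrier: all X-component reads above must land before the scratch is reused for the Y components below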
+         barrier();
+         memoryBarrierShared();
+
        // Get Shuffled Values Y for virtual threads
        for (uint t = 0u; t < num_virtual_threads; t++)
        {
@@ -218,9 +385,9 @@ void nbl_glsl_ext_FFT(bool is_inverse)
            uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
            nbl_glsl_complex shuffled_value = shuffled_values[t];

-             nbl_glsl_complex twiddle = (is_inverse)
-                 ? nbl_glsl_ext_FFT_twiddle(tid, i, logTwo)
-                 : nbl_glsl_ext_FFT_twiddleInverse(tid, i, logTwo);
+             nbl_glsl_complex twiddle = (!is_inverse)
+                 ? nbl_glsl_ext_FFT_twiddle_OLD(tid, i, logTwo)
+                 : nbl_glsl_ext_FFT_twiddleInverse_OLD(tid, i, logTwo);

            nbl_glsl_complex this_value = current_values[t];
@@ -236,7 +403,7 @@ void nbl_glsl_ext_FFT(bool is_inverse)
        {
            uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
            uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
-             nbl_glsl_complex complex_value = (is_inverse)
+             nbl_glsl_complex complex_value = (!is_inverse)
                ? current_values[t]
                : current_values[t] / dataLength;
