Skip to content

Commit 48e514b

Browse files
plan for real DFTs
1 parent e784f8a commit 48e514b

File tree

12 files changed: 77 additions and 80 deletions.

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
// WorkGroup Size
2-
3-
4-
#ifndef _NBL_GLSL_EXT_FFT_MAX_CHANNELS
5-
#define _NBL_GLSL_EXT_FFT_MAX_CHANNELS 4
6-
#endif
7-
82
#ifndef _NBL_GLSL_WORKGROUP_SIZE_
93
#define _NBL_GLSL_WORKGROUP_SIZE_ 256
104
#endif

examples_tests/49.ComputeFFT/last_fft.comp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
#ifndef _NBL_GLSL_EXT_DEFAULT_COMPUTE_FFT_INCLUDED_
2-
#define _NBL_GLSL_EXT_DEFAULT_COMPUTE_FFT_INCLUDED_
3-
4-
// WorkGroup Size
5-
6-
#ifndef _NBL_GLSL_EXT_FFT_MAX_CHANNELS
7-
#define _NBL_GLSL_EXT_FFT_MAX_CHANNELS 4
8-
#endif
9-
101
#ifndef _NBL_GLSL_WORKGROUP_SIZE_
112
#define _NBL_GLSL_WORKGROUP_SIZE_ 256
123
#endif
@@ -83,6 +74,4 @@ void main()
8374
{
8475
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
8576
}
86-
}
87-
88-
#endif
77+
}

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ R"===(#version 430 core
7474
7575
#define _NBL_GLSL_WORKGROUP_SIZE_ %u
7676
#define _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_ %u
77-
#define _NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD %u
7877
7978
#include "../fft_convolve_ifft.comp"
8079
@@ -83,13 +82,11 @@ R"===(#version 430 core
8382
const size_t extraSize = 32 + 32 + 32 + 32;
8483

8584
constexpr uint32_t DEFAULT_WORK_GROUP_SIZE = 256u;
86-
const uint32_t maxItemsPerThread = (maxPaddedDimensionSize - 1u) / (DEFAULT_WORK_GROUP_SIZE) + 1u;
8785
auto shader = core::make_smart_refctd_ptr<ICPUBuffer>(strlen(sourceFmt)+extraSize+1u);
8886
snprintf(
8987
reinterpret_cast<char*>(shader->getPointer()),shader->getSize(), sourceFmt,
9088
DEFAULT_WORK_GROUP_SIZE,
91-
maxPaddedDimensionSize,
92-
maxItemsPerThread
89+
maxPaddedDimensionSize
9390
);
9491

9592
auto cpuSpecializedShader = core::make_smart_refctd_ptr<ICPUSpecializedShader>(
@@ -196,7 +193,6 @@ R"===(#version 430 core
196193
197194
#define _NBL_GLSL_WORKGROUP_SIZE_ %u
198195
#define _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_ %u
199-
#define _NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD %u
200196
201197
#include "../last_fft.comp"
202198
@@ -205,13 +201,11 @@ R"===(#version 430 core
205201
const size_t extraSize = 32 + 32 + 32 + 32;
206202

207203
constexpr uint32_t DEFAULT_WORK_GROUP_SIZE = 256u;
208-
const uint32_t maxItemsPerThread = (maxPaddedDimensionSize - 1u) / (DEFAULT_WORK_GROUP_SIZE) + 1u;
209204
auto shader = core::make_smart_refctd_ptr<ICPUBuffer>(strlen(sourceFmt)+extraSize+1u);
210205
snprintf(
211206
reinterpret_cast<char*>(shader->getPointer()),shader->getSize(), sourceFmt,
212207
DEFAULT_WORK_GROUP_SIZE,
213-
maxPaddedDimensionSize,
214-
maxItemsPerThread
208+
maxPaddedDimensionSize
215209
);
216210

217211
auto cpuSpecializedShader = core::make_smart_refctd_ptr<ICPUSpecializedShader>(

include/nbl/builtin/glsl/ext/FFT/default_compute_fft.comp

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,7 @@
1-
#ifndef _NBL_GLSL_EXT_DEFAULT_COMPUTE_FFT_INCLUDED_
2-
#define _NBL_GLSL_EXT_DEFAULT_COMPUTE_FFT_INCLUDED_
3-
4-
// WorkGroup Size
5-
61
#ifndef USE_SSBO_FOR_INPUT
72
#error "USE_SSBO_FOR_INPUT should be defined."
83
#endif
94

10-
#ifndef _NBL_GLSL_EXT_FFT_MAX_CHANNELS
11-
#define _NBL_GLSL_EXT_FFT_MAX_CHANNELS 4
12-
#endif
13-
145
#ifndef _NBL_GLSL_WORKGROUP_SIZE_
156
#define _NBL_GLSL_WORKGROUP_SIZE_ 256
167
#endif
@@ -127,6 +118,4 @@ void main()
127118
{
128119
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
129120
}
130-
}
131-
132-
#endif
121+
}

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,10 @@
88
#include <nbl/builtin/glsl/math/complex.glsl>
99
#include <nbl/builtin/glsl/ext/FFT/parameters.glsl>
1010

11-
#ifndef _NBL_GLSL_EXT_FFT_MAX_CHANNELS
12-
#error "_NBL_GLSL_EXT_FFT_MAX_CHANNELS should be defined."
13-
#endif
14-
1511
#ifndef _NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_
1612
#error "_NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_ should be defined."
1713
#endif
1814

19-
#ifndef _NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD
20-
#error "_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD should be defined."
21-
#endif
22-
2315
#include "nbl/builtin/glsl/workgroup/shared_fft.glsl"
2416

2517
// Push Constants
@@ -72,7 +64,8 @@ uvec3 nbl_glsl_ext_FFT_getCoordinates(in uint tidx)
7264
#include "nbl/builtin/glsl/workgroup/fft.glsl"
7365

7466

75-
nbl_glsl_complex nbl_glsl_ext_FFT_impl_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD*2u]; // TODO: redo later
67+
nbl_glsl_complex nbl_glsl_ext_FFT_impl_values[(_NBL_GLSL_EXT_FFT_MAX_DIM_SIZE_-1u)/_NBL_GLSL_WORKGROUP_SIZE_+1u];
68+
7669
void nbl_glsl_ext_FFT_loop(in bool is_inverse, in uint virtual_thread_count, in uint step)
7770
{
7871
for(uint t=0u; t<virtual_thread_count; t++)
@@ -90,7 +83,7 @@ void nbl_glsl_ext_FFT_loop(in bool is_inverse, in uint virtual_thread_count, in
9083
nbl_glsl_FFT_DIF_radix2(twiddle,nbl_glsl_ext_FFT_impl_values[lo_ix],nbl_glsl_ext_FFT_impl_values[hi_ix]);
9184
}
9285
}
93-
// TODO: try radix-4 or even radix-8 for perf
86+
9487
void nbl_glsl_ext_FFT(bool is_inverse, uint channel)
9588
{
9689
// Virtual Threads Calculation
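(File path missing from this capture; the hunk below belongs to a different file than fft.glsl above.)

Lines changed: 3 additions & 9 deletions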
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
#version 430 core
2-
3-
#ifndef _NBL_GLSL_EXT_FFT_NORMALIZATION_INCLUDED_
4-
#define _NBL_GLSL_EXT_FFT_NORMALIZATION_INCLUDED_
5-
62
layout(local_size_x=256, local_size_y=1, local_size_z=1) in;
73

8-
#define complex_value vec2
4+
#include <nbl/builtin/glsl/math/complex.glsl>
95

106
layout(set=0, binding=0) restrict readonly buffer InBuffer
117
{
@@ -20,8 +16,6 @@ layout(set=0, binding=1) restrict buffer OutBuffer
2016
void main()
2117
{
2218
float power = length(in_data[0]);
23-
vec2 normalized_data = in_data[gl_GlobalInvocationID.x];// / power;
19+
vec2 normalized_data = in_data[gl_GlobalInvocationID.x]/power;
2420
out_data[gl_GlobalInvocationID.x] = normalized_data;
25-
}
26-
27-
#endif
21+
}

include/nbl/builtin/glsl/math/complex.glsl

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
#include <nbl/builtin/glsl/math/constants.glsl>
99
#include <nbl/builtin/glsl/math/functions.glsl>
1010

11+
#define nbl_glsl_complex16_t uint
12+
1113
#define nbl_glsl_complex vec2
1214
#define nbl_glsl_cvec2 mat2
1315
#define nbl_glsl_cvec3 mat3x2
1416
#define nbl_glsl_cvec4 mat4x2
1517

1618
nbl_glsl_complex nbl_glsl_expImaginary(in float _theta)
1719
{
18-
float r = cos(_theta);
19-
float i = sin(_theta);
20-
return vec2(r, i);
20+
nbl_glsl_complex retval;
21+
nbl_glsl_sincos(_theta,retval.y,retval.x);
22+
return retval;
2123
}
2224

2325
nbl_glsl_complex nbl_glsl_complex_mul(in nbl_glsl_complex rhs, in nbl_glsl_complex lhs)
@@ -32,16 +34,47 @@ nbl_glsl_complex nbl_glsl_complex_add(in nbl_glsl_complex rhs, in nbl_glsl_compl
3234
return rhs + lhs;
3335
}
3436

37+
nbl_glsl_complex16_t nbl_glsl_complex16_t_conjugate(in nbl_glsl_complex16_t complex) {
38+
return complex^0x80000000u;
39+
}
3540
nbl_glsl_complex nbl_glsl_complex_conjugate(in nbl_glsl_complex complex) {
36-
return complex * vec2(1, -1);
41+
return nbl_glsl_complex(complex.x,-complex.y);
3742
}
3843

3944

4045
// FFT
41-
#include <nbl/builtin/glsl/math/complex.glsl>
46+
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in uint k, in float N)
47+
{
48+
const float arg = -2.f*nbl_glsl_PI*float(k)/N;
49+
nbl_glsl_complex retval;
50+
retval.x = cos(arg);
51+
retval.y = sqrt(1.f-retval.x*retval.x); // twiddle is always half the range, so no conditional -1.f needed
52+
return retval;
53+
}
54+
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in uint k, in uint logTwoN)
55+
{
56+
return nbl_glsl_FFT_half_twiddle(k,float(1<<logTwoN));
57+
}
58+
59+
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in bool is_inverse, in uint k, in float N)
60+
{
61+
nbl_glsl_complex twiddle = nbl_glsl_FFT_half_twiddle(k,N);
62+
if (is_inverse)
63+
return nbl_glsl_complex_conjugate(twiddle);
64+
return twiddle;
65+
}
66+
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in bool is_inverse, in uint k, in uint logTwoN)
67+
{
68+
return nbl_glsl_FFT_half_twiddle(is_inverse,k,float(1<<logTwoN));
69+
}
70+
71+
4272
nbl_glsl_complex nbl_glsl_FFT_twiddle(in uint k, in float N)
4373
{
44-
return nbl_glsl_expImaginary(-2.f*nbl_glsl_PI*float(k)/N);
74+
nbl_glsl_complex retval;
75+
retval.x = cos(-2.f*nbl_glsl_PI*float(k)/N);
76+
retval.y = sqrt(1.f-retval.x*retval.x); // twiddle is always half the range, so no conditional -1.f needed
77+
return retval;
4578
}
4679
nbl_glsl_complex nbl_glsl_FFT_twiddle(in uint k, in uint logTwoN)
4780
{
@@ -50,7 +83,7 @@ nbl_glsl_complex nbl_glsl_FFT_twiddle(in uint k, in uint logTwoN)
5083

5184
nbl_glsl_complex nbl_glsl_FFT_twiddle(in bool is_inverse, in uint k, in float N)
5285
{
53-
nbl_glsl_complex twiddle = nbl_glsl_FFT_twiddle(k, N);
86+
nbl_glsl_complex twiddle = nbl_glsl_FFT_twiddle(k,N);
5487
if (is_inverse)
5588
return nbl_glsl_complex_conjugate(twiddle);
5689
return twiddle;
@@ -60,6 +93,8 @@ nbl_glsl_complex nbl_glsl_FFT_twiddle(in bool is_inverse, in uint k, in uint log
6093
return nbl_glsl_FFT_twiddle(is_inverse,k,float(1<<logTwoN));
6194
}
6295

96+
97+
6398
// decimation in time
6499
void nbl_glsl_FFT_DIT_radix2(in nbl_glsl_complex twiddle, inout nbl_glsl_complex lo, inout nbl_glsl_complex hi)
65100
{

include/nbl/builtin/glsl/math/functions.glsl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ void nbl_glsl_sincos(in float theta, out float s, out float c)
191191
{
192192
c = cos(theta);
193193
s = sqrt(1.0-c*c);
194-
s = theta<0.0 ? -s:s; // TODO: do with XOR
194+
s = theta<0.0 ? -s:s; // TODO: test with XOR
195+
//s = uintBitsToFloat(floatBitsToUint(s)^(floatBitsToUint(theta)&0x80000000u));
195196
}
196197

197198
mat2x3 nbl_glsl_frisvad(in vec3 n)

include/nbl/builtin/glsl/subgroup/fft.glsl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
//TODO: optimization for DFT of real signal
1616

1717
// TODO: with stockham or something that does not require stupid shuffles to extract and pack
18+
// https://cnx.org/contents/[book-id@version obscured by e-mail protection in this capture]:1aiTU8is@6/Alternate-FFT-Structures
19+
// These twiddle factors can be precomputed once and stored in an array in computer memory, and accessed in the FFT algorithm by table lookup. This simple technique yields very substantial savings and is almost always used in practice.
1820
void nbl_glsl_subgroupFFT_loop(in bool is_inverse, in uint stride, inout nbl_glsl_complex lo, inout nbl_glsl_complex hi)
1921
{
2022
const uint sub_ix = nbl_glsl_SubgroupInvocationID&(stride-1u);

include/nbl/builtin/glsl/workgroup/fft.glsl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
#endif
2828

2929

30-
//TODO: optimization for DFT of real signal
31-
30+
//TODO: try radix-4 or even radix-8 for perf
3231

3332
void nbl_glsl_workgroupFFT_loop(in bool is_inverse, in uint stride)
3433
{
@@ -42,7 +41,10 @@ void nbl_glsl_workgroupFFT_loop(in bool is_inverse, in uint stride)
4241
nbl_glsl_complex low = nbl_glsl_complex(uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[lo_x_ix]),uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[lo_y_ix]));
4342
nbl_glsl_complex high = nbl_glsl_complex(uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[hi_x_ix]),uintBitsToFloat(_NBL_GLSL_SCRATCH_SHARED_DEFINED_[hi_y_ix]));
4443

45-
nbl_glsl_complex twiddle = nbl_glsl_FFT_twiddle(is_inverse,sub_ix,float(stride<<1u));
44+
nbl_glsl_complex twiddle = nbl_glsl_complex(1.f,0.f);
45+
if (stride!=1u)
46+
twiddle = nbl_glsl_FFT_twiddle(is_inverse,sub_ix,float(stride<<1u));
47+
4648
if (is_inverse)
4749
nbl_glsl_FFT_DIT_radix2(twiddle,low,high);
4850
else
@@ -97,16 +99,22 @@ void nbl_glsl_workgroupFFT(in bool is_inverse, inout nbl_glsl_complex lo, inout
9799
if (is_inverse)
98100
{
99101
nbl_glsl_FFT_DIT_radix2(nbl_glsl_FFT_twiddle(true,gl_LocalInvocationIndex,doubleWorkgroupSize),lo,hi);
100-
101-
const float doubleSubgroupSize = float(nbl_glsl_SubgroupSize<<1u);
102-
lo /= doubleSubgroupSize;
103-
hi /= doubleSubgroupSize;
104-
const float scaleFactor = float(nbl_glsl_SubgroupSize<<1u)/doubleWorkgroupSize;
105-
lo *= scaleFactor;
106-
hi *= scaleFactor;
102+
103+
lo /= doubleWorkgroupSize;
104+
hi /= doubleWorkgroupSize;
107105
}
108106
}
109107

108+
#if 0 // TODO
109+
// Computes Forward FFT of two real signals
110+
void nbl_glsl_workgroupRealFFT(in bool is_inverse, in float sequenceALo, in float sequenceAHi, in float sequenceBLo, in float sequenceBHi)
111+
{
112+
nbl_glsl_complex lo = nbl_glsl_complex(sequenceALo,sequenceBLo);
113+
nbl_glsl_complex hi = nbl_glsl_complex(sequenceAHi,sequenceBHi);
114+
nbl_glsl_workgroupFFT(false,lo,hi);
115+
// extract aDFT and bDFT by using sorensens method
116+
}
117+
#endif
110118

111119

112120
#endif

0 commit comments

Comments
 (0)