Skip to content

Commit b1b231f

Browse files
committed
Some Fixes and Cleanups
1 parent 95fe2b7 commit b1b231f

File tree

2 files changed

+39
-84
lines changed

2 files changed

+39
-84
lines changed

examples_tests/49.ComputeFFT/convolve.comp

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,24 +8,21 @@
88

99
layout(local_size_x=256, local_size_y=1, local_size_z=1) in;
1010

11-
struct nbl_glsl_ext_FFT_output_t
12-
{
13-
vec2 complex_value;
14-
};
11+
#define complex_value vec2
1512

1613
layout(set=0, binding=0) restrict readonly buffer SrcBuffer
1714
{
18-
nbl_glsl_ext_FFT_output_t src_data[];
15+
complex_value src_data[];
1916
};
2017

2118
layout(set=0, binding=1) restrict readonly buffer KernelBuffer
2219
{
23-
nbl_glsl_ext_FFT_output_t ker_data[];
20+
complex_value ker_data[];
2421
};
2522

2623
layout(set=0, binding=2) restrict buffer OutputBuffer
2724
{
28-
nbl_glsl_ext_FFT_output_t out_data[];
25+
complex_value out_data[];
2926
};
3027

3128
layout(push_constant) uniform PushConstants
@@ -38,7 +35,7 @@ void main()
3835
{
3936
// if not already normalized -> divide by power
4037
// float power = length(ker_data[0]); // ker_data elements are plain vec2 (complex_value) after this change
41-
vec2 kerData = ker_data[gl_GlobalInvocationID.x].complex_value;
42-
out_data[gl_GlobalInvocationID.x].complex_value =
43-
nbl_glsl_complex_mul(src_data[gl_GlobalInvocationID.x].complex_value, kerData);
38+
vec2 kerData = ker_data[gl_GlobalInvocationID.x];
39+
out_data[gl_GlobalInvocationID.x] =
40+
nbl_glsl_complex_mul(src_data[gl_GlobalInvocationID.x], kerData);
4441
}

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 32 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -74,95 +74,53 @@ void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in vec2 comp
7474
#endif
7575

7676
// Count Leading Zeroes: 31 - findMSB(x); findMSB(0) is -1, so x == 0 yields 32 (same as the old shift-cascade version)
77-
uint clz(in uint x)
77+
uint nbl_glsl_ext_FFT_clz(in uint x)
7878
{
79-
uint n = 0;
80-
if (x == 0) { return 32; }
81-
if (x <= 0x0000ffff) { n += 16; x <<= 16; }
82-
if (x <= 0x00ffffff) { n += 8; x <<= 8; }
83-
if (x <= 0x0fffffff) { n += 4; x <<= 4; }
84-
if (x <= 0x3fffffff) { n += 2; x <<= 2; }
85-
if (x <= 0x7fffffff) { n++; };
86-
return n;
79+
return 31 - findMSB(x);
8780
}
8881

89-
uint reverseBits(in uint x)
82+
uint nbl_glsl_ext_FFT_reverseBits(in uint x)
9083
{
91-
uint count = 4 * 8 - 1;
92-
uint reverse_num = x;
93-
94-
x >>= 1;
95-
while(x > 0)
96-
{
97-
reverse_num <<= 1;
98-
reverse_num |= x & 1;
99-
x >>= 1;
100-
count--;
101-
}
102-
reverse_num <<= count;
103-
return reverse_num;
84+
return bitfieldReverse(x);
10485
}
10586

106-
uint calculate_twiddle_power(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
87+
uint nbl_glsl_ext_FFT_calculateTwiddlePower(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
10788
{
10889
return (threadId & ((N / (1u << (logTwoN - iteration))) * 2 - 1)) * ((1u << (logTwoN - iteration)) / 2);;
10990
}
11091

111-
vec2 twiddle(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
92+
vec2 nbl_glsl_ext_FFT_twiddle(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
11293
{
113-
uint k = calculate_twiddle_power(threadId, iteration, logTwoN, N);
94+
uint k = nbl_glsl_ext_FFT_calculateTwiddlePower(threadId, iteration, logTwoN, N);
11495
return nbl_glsl_eITheta(-1 * 2 * nbl_glsl_PI * k / N);
11596
}
11697

117-
vec2 twiddle_inv(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
98+
vec2 nbl_gnbl_glsl_ext_FFT_twiddleInverse(in uint threadId, in uint iteration, in uint logTwoN, in uint N)
11899
{
119-
float k = calculate_twiddle_power(threadId, iteration, logTwoN, N);
120-
return nbl_glsl_eITheta(2 * nbl_glsl_PI * k / N);
100+
return nbl_glsl_complex_conjugate(nbl_glsl_ext_FFT_twiddle(threadId, iteration, logTwoN, N));
121101
}
122102

123-
uint getChannel()
103+
uint nbl_glsl_ext_FFT_getChannel()
124104
{
125-
if(pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_X_) {
126-
return gl_WorkGroupID.x;
127-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Y_) {
128-
return gl_WorkGroupID.y;
129-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Z_) {
130-
return gl_WorkGroupID.z;
131-
} else {
132-
return 0;
133-
}
105+
return gl_WorkGroupID[pc.direction];
134106
}
135107

136-
uvec3 getCoordinates(in uint tidx)
108+
uvec3 nbl_glsl_ext_FFT_getCoordinates(in uint tidx)
137109
{
138-
if(pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_X_) {
139-
return uvec3(tidx, gl_WorkGroupID.y, gl_WorkGroupID.z);
140-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Y_) {
141-
return uvec3(gl_WorkGroupID.x, tidx, gl_WorkGroupID.z);
142-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Z_) {
143-
return uvec3(gl_WorkGroupID.x, gl_WorkGroupID.y, tidx);
144-
} else {
145-
return uvec3(0,0,0);
146-
}
110+
uvec3 tmp = gl_WorkGroupID;
111+
tmp[pc.direction] = tidx;
112+
return tmp;
147113
}
148114

149-
uvec3 getBitReversedCoordinates(in uvec3 coords, in uint leadingZeroes)
115+
uvec3 nbl_glsl_ext_FFT_getBitReversedCoordinates(in uvec3 coords, in uint leadingZeroes)
150116
{
151-
if(pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_X_) {
152-
uint bitReversedIndex = reverseBits(coords.x) >> leadingZeroes;
153-
return uvec3(bitReversedIndex, coords.y, coords.z);
154-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Y_) {
155-
uint bitReversedIndex = reverseBits(coords.y) >> leadingZeroes;
156-
return uvec3(coords.x, bitReversedIndex, coords.z);
157-
} else if (pc.direction == _NBL_GLSL_EXT_FFT_DIRECTION_Z_) {
158-
uint bitReversedIndex = reverseBits(coords.z) >> leadingZeroes;
159-
return uvec3(coords.x, coords.y, bitReversedIndex);
160-
} else {
161-
return uvec3(0,0,0);
162-
}
117+
uint bitReversedIndex = nbl_glsl_ext_FFT_reverseBits(coords[pc.direction]) >> leadingZeroes;
118+
uvec3 tmp = coords;
119+
tmp[pc.direction] = bitReversedIndex;
120+
return tmp;
163121
}
164122

165-
uint getDimLength(uvec3 dimension)
123+
uint nbl_glsl_ext_FFT_getDimLength(uvec3 dimension)
166124
{
167125
uint dataLength = 0;
168126

@@ -235,14 +193,14 @@ vec2 nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint channel) {
235193
void nbl_glsl_ext_FFT()
236194
{
237195
// Virtual Threads Calculation
238-
uint dataLength = getDimLength(pc.padded_dimension);
196+
uint dataLength = nbl_glsl_ext_FFT_getDimLength(pc.padded_dimension);
239197
uint num_virtual_threads = uint(ceil(float(dataLength) / float(_NBL_GLSL_EXT_FFT_BLOCK_SIZE_X_DEFINED_)));
240198
uint thread_offset = gl_LocalInvocationIndex;
241199

242-
uint channel = getChannel();
200+
uint channel = nbl_glsl_ext_FFT_getChannel();
243201

244202
// Pass 0: Bit Reversal
245-
uint leadingZeroes = clz(dataLength) + 1;
203+
uint leadingZeroes = nbl_glsl_ext_FFT_clz(dataLength) + 1;
246204
uint logTwo = 32 - leadingZeroes;
247205

248206
vec2 current_values[_NBL_GLSL_EXT_FFT_MAX_ITEMS_PER_THREAD];
@@ -252,8 +210,8 @@ void nbl_glsl_ext_FFT()
252210
for(uint t = 0; t < num_virtual_threads; t++)
253211
{
254212
uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_BLOCK_SIZE_X_DEFINED_;
255-
uvec3 coords = getCoordinates(tid);
256-
uvec3 bitReversedCoords = getBitReversedCoordinates(coords, leadingZeroes);
213+
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
214+
uvec3 bitReversedCoords = nbl_glsl_ext_FFT_getBitReversedCoordinates(coords, leadingZeroes);
257215

258216
current_values[t] = nbl_glsl_ext_FFT_getPaddedData(bitReversedCoords, channel);
259217
}
@@ -299,24 +257,24 @@ void nbl_glsl_ext_FFT()
299257
uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_BLOCK_SIZE_X_DEFINED_;
300258
vec2 shuffled_value = shuffled_values[t];
301259

302-
vec2 twiddle = (0 == pc.is_inverse)
303-
? twiddle(tid, i, logTwo, dataLength)
304-
: twiddle_inv(tid, i, logTwo, dataLength);
260+
vec2 nbl_glsl_ext_FFT_twiddle = (0 == pc.is_inverse)
261+
? nbl_glsl_ext_FFT_twiddle(tid, i, logTwo, dataLength)
262+
: nbl_gnbl_glsl_ext_FFT_twiddleInverse(tid, i, logTwo, dataLength);
305263

306264
vec2 this_value = current_values[t];
307265

308266
if(0 < uint(tid & mask)) {
309-
current_values[t] = shuffled_value + nbl_glsl_complex_mul(twiddle, this_value);
267+
current_values[t] = shuffled_value + nbl_glsl_complex_mul(nbl_glsl_ext_FFT_twiddle, this_value);
310268
} else {
311-
current_values[t] = this_value + nbl_glsl_complex_mul(twiddle, shuffled_value);
269+
current_values[t] = this_value + nbl_glsl_complex_mul(nbl_glsl_ext_FFT_twiddle, shuffled_value);
312270
}
313271
}
314272
}
315273

316274
for(uint t = 0; t < num_virtual_threads; t++)
317275
{
318276
uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_BLOCK_SIZE_X_DEFINED_;
319-
uvec3 coords = getCoordinates(tid);
277+
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
320278
vec2 complex_value = (0 == pc.is_inverse)
321279
? current_values[t]
322280
: current_values[t] / dataLength;

0 commit comments

Comments
 (0)