Skip to content

Commit e3f3fc5

Browse files
committed
fft_convolve_ifft shader start
+ cleanups
1 parent 92709c5 commit e3f3fc5

File tree

4 files changed

+98
-15
lines changed

4 files changed

+98
-15
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// WorkGroup Size
// One-dimensional workgroup: 256 invocations along x, 1 along y and z.
2+
3+
#define _NBL_GLSL_WORKGROUP_SIZE_ 256
4+
5+
layout(local_size_x=_NBL_GLSL_WORKGROUP_SIZE_, local_size_y=1, local_size_z=1) in;
6+
7+
8+
// The four *_DEFINED_ macros below are set before fft.glsl is included, and this
// shader then provides its own nbl_glsl_ext_FFT_getParameters / getData / setData /
// getPaddedData further down.
// NOTE(review): presumably these macros tell fft.glsl not to emit default versions
// of those functions -- confirm against fft.glsl.
#define _NBL_GLSL_EXT_FFT_GET_PARAMETERS_DEFINED_
9+
#define _NBL_GLSL_EXT_FFT_GET_DATA_DEFINED_
10+
#define _NBL_GLSL_EXT_FFT_SET_DATA_DEFINED_
11+
#define _NBL_GLSL_EXT_FFT_GET_PADDED_DATA_DEFINED_
12+
#include "nbl/builtin/glsl/ext/FFT/fft.glsl"
13+
14+
// Input Descriptor
15+
16+
// Source data for nbl_glsl_ext_FFT_getData(), and destination of convolve().
// Deliberately NOT `readonly`: convolve() stores the product of outData and kerData
// back into inData, and GLSL rejects any write to a readonly-qualified buffer
// variable at compile time.
layout(set=0, binding=0) restrict buffer InputBuffer
17+
{
18+
nbl_glsl_complex inData[];
19+
};
20+
21+
// Read-only buffer of complex kernel values; convolve() multiplies outData by
// kerData element-wise.
// NOTE(review): presumably already in the frequency domain -- confirm with the host code.
layout(set=0, binding=1) restrict readonly buffer KernelBuffer
22+
{
23+
nbl_glsl_complex kerData[];
24+
};
25+
26+
// Output Descriptor
27+
28+
// Read/write buffer: written by nbl_glsl_ext_FFT_setData() and read back by convolve().
layout(set=0, binding=2) restrict buffer OutputBuffer
29+
{
30+
nbl_glsl_complex outData[];
31+
};
32+
33+
// Get/Set Data Function
34+
// Push-constant block holding the FFT parameters as a single struct at offset 0.
// It is read through nbl_glsl_ext_FFT_getParameters().
layout(push_constant) uniform PushConstants
35+
{
36+
layout (offset = 0) nbl_glsl_ext_FFT_Parameters_t params;
37+
} pc;
38+
39+
// Returns a copy of the FFT parameter struct stored in the push-constant block `pc`.
nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters() {
40+
nbl_glsl_ext_FFT_Parameters_t ret;
41+
ret = pc.params;
42+
return ret;
43+
}
44+
45+
// Reads one complex element from inData.
//   coordinate - 3D position of the element
//   channel    - channel index; each channel occupies dimension.x*dimension.y*dimension.z elements
// Returns inData at the flattened index (x varies fastest, then y, then z, then channel).
// The extents come from nbl_glsl_ext_FFT_getDimensions(), whereas setData() uses the
// padded dimensions. No bounds check is performed here.
nbl_glsl_complex nbl_glsl_ext_FFT_getData(in uvec3 coordinate, in uint channel)
46+
{
47+
nbl_glsl_complex retValue = nbl_glsl_complex(0, 0);
48+
uvec3 dimension = nbl_glsl_ext_FFT_getDimensions();
49+
// Linear index: channel-major, then z, then y, then x.
uint index = channel * (dimension.x * dimension.y * dimension.z) + coordinate.z * (dimension.x * dimension.y) + coordinate.y * (dimension.x) + coordinate.x;
50+
retValue = inData[index];
51+
return retValue;
52+
}
53+
54+
// Writes one complex element to outData.
//   coordinate    - 3D position of the element
//   channel       - channel index
//   complex_value - value to store
// The flattened index uses the extents from nbl_glsl_ext_FFT_getPaddedDimensions(),
// with x varying fastest, then y, then z, then channel. No bounds check is performed here.
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in nbl_glsl_complex complex_value)
55+
{
56+
uvec3 dimension = nbl_glsl_ext_FFT_getPaddedDimensions();
57+
// Linear index: channel-major, then z, then y, then x.
uint index = channel * (dimension.x * dimension.y * dimension.z) + coordinate.z * (dimension.x * dimension.y) + coordinate.y * (dimension.x) + coordinate.x;
58+
outData[index] = complex_value;
59+
}
60+
61+
// Reads an element, handling coordinates that fall outside the dimensions returned by
// nbl_glsl_ext_FFT_getDimensions().
//   coordinate - 3D position, possibly outside those dimensions
//   channel    - channel index
// Returns (0, 0) when the coordinate is out of range and the padding type is
// _NBL_GLSL_EXT_FFT_FILL_WITH_ZERO_. In every other case it returns the element at the
// coordinate clamped to the last valid index on each axis.
nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint channel) {
62+
63+
// Largest valid coordinate per axis. This wraps around if any dimension is 0.
uvec3 max_coord = nbl_glsl_ext_FFT_getDimensions() - uvec3(1u);
64+
uvec3 clamped_coord = min(coordinate, max_coord);
65+
66+
// In GLSL, `!=` on two vectors yields a single bool (true if any component differs),
// so bvec3(...) only replicates that bool and any() returns it unchanged.
bool is_out_of_range = any(bvec3(coordinate!=clamped_coord));
67+
68+
uint paddingType = nbl_glsl_ext_FFT_getPaddingType();
69+
70+
if (_NBL_GLSL_EXT_FFT_FILL_WITH_ZERO_ == paddingType && is_out_of_range) {
71+
return nbl_glsl_complex(0, 0);
72+
}
73+
74+
// In range, or a padding type other than zero-fill: sample at the clamped coordinate.
return nbl_glsl_ext_FFT_getData(clamped_coord, channel);
75+
}
76+
77+
// Multiplies one element of outData by the matching element of kerData (complex
// multiplication) and stores the product into inData. The store requires the inData
// buffer block to be writable.
// NOTE(review): idx is hard-coded to 0, so every invocation reads and writes element 0
// only. This looks like a placeholder; the intended per-invocation indexing is not
// visible here.
void convolve()
78+
{
79+
uint idx = 0;
80+
inData[idx] = nbl_glsl_complex_mul(outData[idx], kerData[idx]);
81+
}
82+
83+
// Entry point: runs an FFT, then convolve(), then a second FFT in the opposite
// direction, with a workgroup barrier between the stages.
void main()
84+
{
85+
// First pass, in the direction given by nbl_glsl_ext_FFT_getIsInverse().
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_getIsInverse()); // inData->outData
86+
87+
// barrier() synchronises the invocations of this workgroup only.
// NOTE(review): the data handed between stages lives in buffer blocks
// (inData/outData), not shared variables. Check whether memoryBarrierBuffer() is
// needed here, and whether the memory barrier should precede barrier().
barrier();
88+
memoryBarrierShared();
89+
90+
convolve(); // outData+kerData->inData
91+
92+
// NOTE(review): the same barrier question applies to this pair.
barrier();
93+
memoryBarrierShared();
94+
95+
// Second pass, with the direction flag negated relative to the first pass.
nbl_glsl_ext_FFT(!nbl_glsl_ext_FFT_getIsInverse()); // inData->outData
96+
}

include/nbl/builtin/glsl/ext/FFT/default_compute_fft.comp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,14 @@ layout(set=_NBL_GLSL_EXT_FFT_OUTPUT_SET_DEFINED_, binding=_NBL_GLSL_EXT_FFT_OUTP
6262

6363
// Get/Set Data Function
6464

65-
//TODO: investigate why putting this uint between the 2 other uvec3's don't work
6665
layout(push_constant) uniform PushConstants
6766
{
6867
layout (offset = 0) nbl_glsl_ext_FFT_Parameters_t params;
69-
// layout (offset = 0) uvec3 dimension;
70-
// layout (offset = 16) uvec3 padded_dimension;
71-
// layout (offset = 32) uint direction_isInverse_paddingType; // packed into a uint
7268
} pc;
7369

7470
nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters() {
7571
nbl_glsl_ext_FFT_Parameters_t ret;
7672
ret = pc.params;
77-
// ret.dimension = pc.dimension;
78-
// ret.direction_isInverse_paddingType = pc.direction_isInverse_paddingType;
79-
// ret.padded_dimension = pc.padded_dimension;
8073
return ret;
8174
}
8275

@@ -119,7 +112,7 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
119112

120113
void main()
121114
{
122-
nbl_glsl_ext_FFT();
115+
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_getIsInverse());
123116
}
124117

125118
#endif

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ uint nbl_glsl_ext_FFT_getDimLength(uvec3 dimension)
147147
return dimension[direction];
148148
}
149149

150-
void nbl_glsl_ext_FFT()
150+
void nbl_glsl_ext_FFT(bool is_inverse)
151151
{
152152
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
153153
// Virtual Threads Calculation
@@ -157,8 +157,6 @@ void nbl_glsl_ext_FFT()
157157

158158
uint channel = nbl_glsl_ext_FFT_getChannel();
159159

160-
bool is_inverse = nbl_glsl_ext_FFT_getIsInverse();
161-
162160
// Pass 0: Bit Reversal
163161
uint leadingZeroes = nbl_glsl_clz(dataLength) + 1u;
164162
uint logTwo = 32u - leadingZeroes;

include/nbl/ext/FFT/FFT.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,15 +226,11 @@ class FFT : public core::TotalInterface
226226
params.dimension.y = inputDimension.height;
227227
params.dimension.z = inputDimension.depth;
228228
params.dimension.w = packed;
229-
// params.direction_isInverse_paddingType = packed;
230229
params.padded_dimension.x = paddedInputDimension.width;
231230
params.padded_dimension.y = paddedInputDimension.height;
232231
params.padded_dimension.z = paddedInputDimension.depth;
233232

234233
driver->pushConstants(pipelineLayout, nbl::video::IGPUSpecializedShader::ESS_COMPUTE, 0u, sizeof(Parameters_t), &params);
235-
// driver->pushConstants(pipelineLayout, nbl::video::IGPUSpecializedShader::ESS_COMPUTE, 0u, sizeof(uint32_t) * 3, &inputDimension);
236-
// driver->pushConstants(pipelineLayout, nbl::video::IGPUSpecializedShader::ESS_COMPUTE, sizeof(uint32_t) * 4, sizeof(uint32_t) * 3, &paddedInputDimension);
237-
// driver->pushConstants(pipelineLayout, nbl::video::IGPUSpecializedShader::ESS_COMPUTE, sizeof(uint32_t) * 8, sizeof(uint32_t), &packed);
238234
}
239235

240236
// Kernel Normalization

0 commit comments

Comments
 (0)