Skip to content

Commit 2375c47

Browse files
take a step toward not doing stupid reads in the convolution
1 parent: 48e514b; commit: 2375c47

File tree

3 files changed: 22 additions and 32 deletions.

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,11 @@ layout(set=0, binding=1) restrict readonly buffer KernelBuffer
2626
// Get/Set Data Function
2727
layout(push_constant) uniform PushConstants
2828
{
29-
layout (offset = 0) nbl_glsl_ext_FFT_Parameters_t params;
29+
nbl_glsl_ext_FFT_Parameters_t params;
3030
} pc;
3131

32-
nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters() {
32+
nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters()
33+
{
3334
nbl_glsl_ext_FFT_Parameters_t ret;
3435
ret = pc.params;
3536
return ret;
@@ -67,24 +68,18 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
6768
return nbl_glsl_ext_FFT_getData(clamped_coord, channel);
6869
}
6970

70-
void convolve()
71+
void convolve(in uint ch)
7172
{
72-
uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
7373
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getDimensions();
74-
uint dataLength = nbl_glsl_ext_FFT_Parameters_t_getFFTLength();
75-
76-
uint thread_offset = gl_LocalInvocationIndex;
77-
uint num_virtual_threads = (dataLength-1u)/(_NBL_GLSL_WORKGROUP_SIZE_)+1u;
78-
79-
for(uint ch = 0u; ch < numChannels; ++ch) {
80-
for(uint t = 0u; t < num_virtual_threads; t++)
81-
{
82-
uint tid = thread_offset + t * _NBL_GLSL_WORKGROUP_SIZE_;
83-
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
84-
uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
85-
nbl_glsl_complex temp = inoutData[idx];
86-
inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
87-
}
74+
75+
const uint item_per_thread_count = nbl_glsl_ext_FFT_Parameters_t_getFFTLength()>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
76+
for(uint t = 0u; t < item_per_thread_count; t++)
77+
{
78+
uint tid = gl_LocalInvocationIndex + t * _NBL_GLSL_WORKGROUP_SIZE_;
79+
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
80+
uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
81+
nbl_glsl_complex temp = inoutData[idx];
82+
inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
8883
}
8984
}
9085

@@ -93,17 +88,12 @@ void main()
9388
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
9489
for(uint ch = 0u; ch < numChannels; ++ch)
9590
{
96-
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
97-
}
98-
99-
barrier();
91+
nbl_glsl_ext_FFT(false, ch);
92+
barrier();
10093

101-
convolve(); // inoutData+kerData->inoutData
94+
convolve(ch); // inoutData+kerData->inoutData
10295

103-
barrier();
104-
105-
for(uint ch = 0u; ch < numChannels; ++ch)
106-
{
107-
nbl_glsl_ext_FFT(!nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch); // inoutData->inoutData
96+
barrier();
97+
nbl_glsl_ext_FFT(true, ch); // inoutData->inoutData
10898
}
10999
}

examples_tests/49.ComputeFFT/last_fft.comp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ layout(set=0, binding=1, rgba16f) uniform image2D outImage;
2424

2525
layout(push_constant) uniform PushConstants
2626
{
27-
layout (offset = 0) nbl_glsl_ext_FFT_Parameters_t params;
27+
nbl_glsl_ext_FFT_Parameters_t params;
2828
layout (offset = 32) uvec3 kernel_dimension;
2929
} pc;
3030

@@ -72,6 +72,6 @@ void main()
7272
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
7373
for(uint ch = 0u; ch < numChannels; ++ch)
7474
{
75-
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
75+
nbl_glsl_ext_FFT(true, ch);
7676
}
7777
}

include/nbl/builtin/glsl/ext/FFT/normalization.comp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@ layout(local_size_x=256, local_size_y=1, local_size_z=1) in;
55

66
layout(set=0, binding=0) restrict readonly buffer InBuffer
77
{
8-
complex_value in_data[];
8+
nbl_glsl_complex in_data[];
99
};
1010

1111
layout(set=0, binding=1) restrict buffer OutBuffer
1212
{
13-
complex_value out_data[];
13+
nbl_glsl_complex out_data[];
1414
};
1515

1616
void main()

0 commit comments

Comments
 (0)