@@ -26,10 +26,11 @@ layout(set=0, binding=1) restrict readonly buffer KernelBuffer
26
26
// Get/Set Data Function
27
27
layout(push_constant) uniform PushConstants
28
28
{
29
- layout (offset = 0) nbl_glsl_ext_FFT_Parameters_t params;
29
+ nbl_glsl_ext_FFT_Parameters_t params;
30
30
} pc;
31
31
32
- nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters() {
32
+ nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters()
33
+ {
33
34
nbl_glsl_ext_FFT_Parameters_t ret;
34
35
ret = pc.params;
35
36
return ret;
@@ -67,24 +68,18 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
67
68
return nbl_glsl_ext_FFT_getData(clamped_coord, channel);
68
69
}
69
70
70
- void convolve()
71
+ void convolve(in uint ch )
71
72
{
72
- uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
73
73
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getDimensions();
74
- uint dataLength = nbl_glsl_ext_FFT_Parameters_t_getFFTLength();
75
-
76
- uint thread_offset = gl_LocalInvocationIndex;
77
- uint num_virtual_threads = (dataLength-1u)/(_NBL_GLSL_WORKGROUP_SIZE_)+1u;
78
-
79
- for(uint ch = 0u; ch < numChannels; ++ch) {
80
- for(uint t = 0u; t < num_virtual_threads; t++)
81
- {
82
- uint tid = thread_offset + t * _NBL_GLSL_WORKGROUP_SIZE_;
83
- uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
84
- uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
85
- nbl_glsl_complex temp = inoutData[idx];
86
- inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
87
- }
74
+
75
+ const uint item_per_thread_count = nbl_glsl_ext_FFT_Parameters_t_getFFTLength()>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
76
+ for(uint t = 0u; t < item_per_thread_count; t++)
77
+ {
78
+ uint tid = gl_LocalInvocationIndex + t * _NBL_GLSL_WORKGROUP_SIZE_;
79
+ uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
80
+ uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
81
+ nbl_glsl_complex temp = inoutData[idx];
82
+ inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
88
83
}
89
84
}
90
85
@@ -93,17 +88,12 @@ void main()
93
88
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
94
89
for(uint ch = 0u; ch < numChannels; ++ch)
95
90
{
96
- nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
97
- }
98
-
99
- barrier();
91
+ nbl_glsl_ext_FFT(false, ch);
92
+ barrier();
100
93
101
- convolve(); // inoutData+kerData->inoutData
94
+ convolve(ch ); // inoutData+kerData->inoutData
102
95
103
- barrier();
104
-
105
- for(uint ch = 0u; ch < numChannels; ++ch)
106
- {
107
- nbl_glsl_ext_FFT(!nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch); // inoutData->inoutData
96
+ barrier();
97
+ nbl_glsl_ext_FFT(true, ch); // inoutData->inoutData
108
98
}
109
99
}
0 commit comments