Skip to content

Commit 32323bc

Browse files
get rid of stupid triple roundtrip in convolutiion
1 parent 2375c47 commit 32323bc

File tree

3 files changed

+46
-48
lines changed

3 files changed

+46
-48
lines changed

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,32 +68,46 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
6868
return nbl_glsl_ext_FFT_getData(clamped_coord, channel);
6969
}
7070

71-
void convolve(in uint ch)
71+
void convolve(in uint item_per_thread_count, in uint ch)
7272
{
73+
// TODO: decouple kernel size from image size
7374
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getDimensions();
74-
75-
const uint item_per_thread_count = nbl_glsl_ext_FFT_Parameters_t_getFFTLength()>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
76-
for(uint t = 0u; t < item_per_thread_count; t++)
75+
76+
for(uint t=0u; t<item_per_thread_count; t++)
7777
{
7878
uint tid = gl_LocalInvocationIndex + t * _NBL_GLSL_WORKGROUP_SIZE_;
7979
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
80+
//coords &= uvec3(0xffeu);
8081
uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
81-
nbl_glsl_complex temp = inoutData[idx];
82-
inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
82+
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_complex_mul(nbl_glsl_ext_FFT_impl_values[t],kerData[idx]);
8383
}
8484
}
8585

8686
void main()
8787
{
88+
const uint dataLength = nbl_glsl_ext_FFT_Parameters_t_getFFTLength();
89+
const uint item_per_thread_count = dataLength>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
8890
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
8991
for(uint ch = 0u; ch < numChannels; ++ch)
9092
{
91-
nbl_glsl_ext_FFT(false, ch);
93+
// Load Values into local memory
94+
for(uint t=0u; t<item_per_thread_count; t++)
95+
{
96+
const uint tid = (t<<_NBL_GLSL_WORKGROUP_SIZE_LOG2_)|gl_LocalInvocationIndex;
97+
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_ext_FFT_getPaddedData(nbl_glsl_ext_FFT_getCoordinates(tid),ch);
98+
}
99+
nbl_glsl_ext_FFT_preloaded(false,dataLength);
92100
barrier();
93101

94-
convolve(ch); // inoutData+kerData->inoutData
102+
convolve(item_per_thread_count,ch);
95103

96104
barrier();
97-
nbl_glsl_ext_FFT(true, ch); // inoutData->inoutData
105+
nbl_glsl_ext_FFT_preloaded(true,dataLength);
106+
// write out to main memory
107+
for(uint t=0u; t<item_per_thread_count; t++)
108+
{
109+
const uint tid = (t<<_NBL_GLSL_WORKGROUP_SIZE_LOG2_)|gl_LocalInvocationIndex;
110+
nbl_glsl_ext_FFT_setData(nbl_glsl_ext_FFT_getCoordinates(tid),ch,nbl_glsl_ext_FFT_impl_values[t]);
111+
}
98112
}
99113
}

include/nbl/builtin/glsl/ext/FFT/fft.glsl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -84,23 +84,12 @@ void nbl_glsl_ext_FFT_loop(in bool is_inverse, in uint virtual_thread_count, in
8484
}
8585
}
8686

87-
void nbl_glsl_ext_FFT(bool is_inverse, uint channel)
87+
void nbl_glsl_ext_FFT_preloaded(bool is_inverse, in uint dataLength)
8888
{
8989
// Virtual Threads Calculation
90-
const uint dataLength = nbl_glsl_ext_FFT_Parameters_t_getFFTLength();
91-
const uint item_per_thread_count = dataLength>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
92-
9390
const uint halfDataLength = dataLength>>1u;
94-
const uint virtual_thread_count = item_per_thread_count>>1u;
91+
const uint virtual_thread_count = halfDataLength>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
9592

96-
// Load Values into local memory
97-
for(uint t=0u; t<item_per_thread_count; t++)
98-
{
99-
const uint tid = (t<<_NBL_GLSL_WORKGROUP_SIZE_LOG2_)|gl_LocalInvocationIndex;
100-
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_ext_FFT_getPaddedData(nbl_glsl_ext_FFT_getCoordinates(tid),channel);
101-
if (is_inverse)
102-
nbl_glsl_ext_FFT_impl_values[t] /= float(virtual_thread_count);
103-
}
10493
// special forward steps
10594
if (!is_inverse)
10695
for (uint step=halfDataLength; step>_NBL_GLSL_WORKGROUP_SIZE_; step>>=1u)
@@ -111,11 +100,32 @@ void nbl_glsl_ext_FFT(bool is_inverse, uint channel)
111100
const uint lo_ix = t<<1u;
112101
const uint hi_ix = lo_ix|1u;
113102
nbl_glsl_workgroupFFT(is_inverse,nbl_glsl_ext_FFT_impl_values[lo_ix],nbl_glsl_ext_FFT_impl_values[hi_ix]);
103+
if (is_inverse)
104+
{
105+
nbl_glsl_ext_FFT_impl_values[lo_ix] /= float(virtual_thread_count);
106+
nbl_glsl_ext_FFT_impl_values[hi_ix] /= float(virtual_thread_count);
107+
}
114108
}
115109
// special inverse steps
116110
if (is_inverse)
117111
for (uint step=_NBL_GLSL_WORKGROUP_SIZE_<<1u; step<dataLength; step<<=1u)
118112
nbl_glsl_ext_FFT_loop(true,virtual_thread_count,step);
113+
}
114+
115+
void nbl_glsl_ext_FFT(bool is_inverse, uint channel)
116+
{
117+
// Virtual Threads Calculation
118+
const uint dataLength = nbl_glsl_ext_FFT_Parameters_t_getFFTLength();
119+
const uint item_per_thread_count = dataLength>>_NBL_GLSL_WORKGROUP_SIZE_LOG2_;
120+
121+
// Load Values into local memory
122+
for(uint t=0u; t<item_per_thread_count; t++)
123+
{
124+
const uint tid = (t<<_NBL_GLSL_WORKGROUP_SIZE_LOG2_)|gl_LocalInvocationIndex;
125+
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_ext_FFT_getPaddedData(nbl_glsl_ext_FFT_getCoordinates(tid),channel);
126+
}
127+
// do FFT
128+
nbl_glsl_ext_FFT_preloaded(is_inverse,dataLength);
119129
// write out to main memory
120130
for(uint t=0u; t<item_per_thread_count; t++)
121131
{

include/nbl/builtin/glsl/math/complex.glsl

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,6 @@ nbl_glsl_complex nbl_glsl_complex_conjugate(in nbl_glsl_complex complex) {
4343

4444

4545
// FFT
46-
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in uint k, in float N)
47-
{
48-
const float arg = -2.f*nbl_glsl_PI*float(k)/N;
49-
nbl_glsl_complex retval;
50-
retval.x = cos(arg);
51-
retval.y = sqrt(1.f-retval.x*retval.x); // twiddle is always half the range, so no conditional -1.f needed
52-
return retval;
53-
}
54-
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in uint k, in uint logTwoN)
55-
{
56-
return nbl_glsl_FFT_half_twiddle(k,float(1<<logTwoN));
57-
}
58-
59-
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in bool is_inverse, in uint k, in float N)
60-
{
61-
nbl_glsl_complex twiddle = nbl_glsl_FFT_half_twiddle(k,N);
62-
if (is_inverse)
63-
return nbl_glsl_complex_conjugate(twiddle);
64-
return twiddle;
65-
}
66-
nbl_glsl_complex nbl_glsl_FFT_half_twiddle(in bool is_inverse, in uint k, in uint logTwoN)
67-
{
68-
return nbl_glsl_FFT_half_twiddle(is_inverse,k,float(1<<logTwoN));
69-
}
70-
71-
7246
nbl_glsl_complex nbl_glsl_FFT_twiddle(in uint k, in float N)
7347
{
7448
nbl_glsl_complex retval;

0 commit comments

Comments
 (0)