Skip to content

Commit 1101988

Browse files
cut down the dispatch sizes (exploiting texture wrapping)
1 parent c39759b commit 1101988

File tree

4 files changed

+72
-37
lines changed

4 files changed

+72
-37
lines changed

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ void main()
5050
{
5151
const uint tid = (t<<_NBL_GLSL_WORKGROUP_SIZE_LOG2_)|gl_LocalInvocationIndex;
5252
const uint trueDim = nbl_glsl_ext_FFT_Parameters_t_getDimensions()[nbl_glsl_ext_FFT_Parameters_t_getDirection()];
53-
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_ext_FFT_getPaddedData(nbl_glsl_ext_FFT_getPaddedCoordinates(tid,log2FFTSize,trueDim),ch);
53+
nbl_glsl_ext_FFT_impl_values[t] = nbl_glsl_ext_FFT_getPaddedData(ivec3(nbl_glsl_ext_FFT_getCoordinates(tid)),ch);
5454
}
5555
nbl_glsl_ext_FFT_preloaded(false,log2FFTSize);
5656
barrier();

examples_tests/49.ComputeFFT/last_fft.comp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ layout(set=0, binding=1, rgba16f) uniform image2D outImage;
88
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in nbl_glsl_complex complex_value)
99
{
1010
// TODO PC
11-
const ivec2 padding = imageSize(outImage).x!=512u ? ivec2(384,664):ivec2(0);
11+
const ivec2 padding = imageSize(outImage).x!=512u ? ivec2(384,/*664*/0):ivec2(0);
1212
const ivec2 coords = ivec2(coordinate.xy)-padding;
1313

1414
if (all(lessThanEqual(ivec2(0),coords))&&all(greaterThan(imageSize(outImage),coords)))

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -614,17 +614,11 @@ int main()
614614
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_MIRROR,ISampler::ETC_MIRROR};
615615
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding,paddedSrcDim);
616616
{
617-
fftPushConstants[1].input_dimensions.x = 2048u;
618-
fftPushConstants[1].input_strides = fftPushConstants[0].output_strides;
619-
fftPushConstants[1].output_strides.x = 2048u;
620-
fftPushConstants[1].output_strides.y = 1u;
621-
fftPushConstants[2] = fftPushConstants[0];
622-
fftPushConstants[2].input_dimensions.x = 2048u;
623-
fftPushConstants[2].input_dimensions.y = 2048u;
617+
fftPushConstants[2].input_dimensions = fftPushConstants[1].input_dimensions;
624618
{
625-
fftPushConstants[2].input_dimensions.w ^= 0x80000000u;
626-
fftPushConstants[2].input_dimensions.w &= 0xfffffffdu;
619+
fftPushConstants[2].input_dimensions.w = fftPushConstants[0].input_dimensions.w^0x80000000u;
627620
fftPushConstants[2].input_strides = fftPushConstants[1].output_strides;
621+
fftPushConstants[2].output_strides = fftPushConstants[0].input_strides;
628622
}
629623
fftDispatchInfo[2] = fftDispatchInfo[0];
630624
}

include/nbl/ext/FFT/FFT.h

Lines changed: 67 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,8 @@ namespace FFT
1818
{
1919

2020
typedef uint32_t uint;
21-
struct alignas(16) uvec3 {
21+
struct alignas(16) uvec3
22+
{
2223
uint x,y,z;
2324
};
2425
struct alignas(16) uvec4 {
@@ -50,45 +51,85 @@ class FFT final : public core::IReferenceCounted
5051
{
5152
uint32_t passesRequired = 0u;
5253

54+
const auto paddedInputDimensions = padDimensions(extraPaddedInputDimensions);
55+
56+
using SizeAxisPair = std::tuple<uint32_t,uint8_t,uint8_t>;
57+
std::array<SizeAxisPair,3u> passes;
5358
if (numChannels)
5459
{
55-
const auto paddedInputDimensions = padDimensions(extraPaddedInputDimensions);
5660
for (uint32_t i=0u; i<3u; i++)
57-
if ((&inputDimensions.width)[i]>1u)
5861
{
59-
// TODO: rework
60-
auto& dispatch = outInfos[passesRequired];
61-
dispatch.workGroupCount[0] = paddedInputDimensions.width;
62-
dispatch.workGroupCount[1] = paddedInputDimensions.height;
63-
dispatch.workGroupCount[2] = paddedInputDimensions.depth;
64-
dispatch.workGroupCount[i] = 1u;
65-
66-
auto& params = outParams[passesRequired];
67-
params.input_dimensions.x = inputDimensions.width;
68-
params.input_dimensions.y = inputDimensions.height;
69-
params.input_dimensions.z = inputDimensions.depth;
62+
auto dim = (&paddedInputDimensions.width)[i];
63+
if (dim<2u)
64+
continue;
65+
passes[passesRequired++] = {dim,i,paddingType[i]};
66+
}
67+
std::sort(passes.begin(),passes.begin()+passesRequired,[](const auto& lhs, const auto& rhs)->bool{return std::get<0u>(lhs)>std::get<0u>(rhs);});
68+
}
69+
70+
auto computeOutputStride = [](const uvec3& output_dimensions, const auto axis, const auto nextAxis) -> uvec4
71+
{
72+
// coord[axis] = 1u
73+
// coord[nextAxis] = fftLen;
74+
// coord[otherAxis] = fftLen*dimension[nextAxis];
75+
uvec4 stride;
76+
stride.w = output_dimensions.x*output_dimensions.y*output_dimensions.z;
77+
for (auto i=0u; i<3u; i++)
78+
{
79+
auto& coord = (&stride.x)[i];
80+
if (i!=axis)
81+
{
82+
coord = (&output_dimensions.x)[axis];
83+
if (i!=nextAxis)
84+
coord *= (&output_dimensions.x)[nextAxis];
85+
}
86+
else
87+
coord = 1u;
88+
}
89+
return stride;
90+
};
91+
92+
if (passesRequired)
93+
{
94+
uvec3 output_dimensions = {inputDimensions.width,inputDimensions.height,inputDimensions.depth};
95+
for (uint32_t i=0u; i<passesRequired; i++)
96+
{
97+
auto& params = outParams[i];
98+
params.input_dimensions.x = output_dimensions.x;
99+
params.input_dimensions.y = output_dimensions.y;
100+
params.input_dimensions.z = output_dimensions.z;
101+
102+
const auto paddedAxisLen = std::get<0u>(passes[i]);
70103
{
71-
const uint32_t fftSize = (&paddedInputDimensions.width)[i];
72104
assert(paddingType[i]<=asset::ISampler::ETC_MIRROR);
73105
params.input_dimensions.w = (isInverse ? 0x80000000u:0x0u)|
74106
(i<<28u)| // direction
75107
((numChannels-1u)<<26u)| // max channel
76-
(core::findMSB(fftSize)<<3u)| // log2(fftSize)
77-
uint32_t(paddingType[i]);
108+
(core::findMSB(paddedAxisLen)<<3u)| // log2(fftSize)
109+
uint32_t(std::get<2u>(passes[i]));
78110
}
79-
params.input_strides.x = 1u;
80-
params.input_strides.y = paddedInputDimensions.width;
81-
params.input_strides.z = params.input_strides.y*paddedInputDimensions.height;
82-
params.input_strides.w = params.input_strides.z*paddedInputDimensions.depth;
83-
params.output_strides = params.input_strides;
84111

85-
passesRequired++;
112+
const auto passAxis = std::get<1u>(passes[i]);
113+
(&output_dimensions.x)[passAxis] = paddedAxisLen;
114+
if (i)
115+
params.input_strides = outParams[i-1u].output_strides;
116+
else // TODO provide an override for input strides
117+
{
118+
params.input_strides.x = 1u;
119+
params.input_strides.y = inputDimensions.width;
120+
params.input_strides.z = params.input_strides.y * inputDimensions.height;
121+
params.input_strides.w = params.input_strides.z * inputDimensions.depth;
122+
}
123+
params.output_strides = computeOutputStride(output_dimensions,passAxis,std::get<1u>(passes[(i+1u)%passesRequired]));
124+
125+
auto& dispatch = outInfos[i];
126+
dispatch.workGroupCount[0] = output_dimensions.x;
127+
dispatch.workGroupCount[1] = output_dimensions.y;
128+
dispatch.workGroupCount[2] = output_dimensions.z;
129+
dispatch.workGroupCount[passAxis] = 1u;
86130
}
87131
}
88132

89-
if (passesRequired)
90-
outParams[passesRequired-1u].output_strides = outParams[0].input_strides;
91-
92133
return passesRequired;
93134
}
94135
static inline uint32_t buildParameters(

0 commit comments

Comments
 (0)