Skip to content

Commit 5d89705

Browse files
nasty parameter refactor
1 parent 90f1307 commit 5d89705

File tree

6 files changed

+106
-123
lines changed

6 files changed

+106
-123
lines changed

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ void main()
4444
{
4545
const uint log2FFTSize = nbl_glsl_ext_FFT_Parameters_t_getLog2FFTSize();
4646
const uint item_per_thread_count = 0x1u<<(log2FFTSize-_NBL_GLSL_WORKGROUP_SIZE_LOG2_);
47-
48-
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
49-
for(uint ch = 0u; ch < numChannels; ++ch)
47+
48+
for(uint ch=0u; ch<=nbl_glsl_ext_FFT_Parameters_t_getMaxChannel(); ++ch)
5049
{
5150
// Load Values into local memory
5251
for(uint t=0u; t<item_per_thread_count; t++)

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ using namespace nbl::video;
2121
#include "nbl/core/math/intutil.h"
2222
#include "nbl/core/math/glslFunctions.h"
2323

24-
struct DispatchInfo_t
25-
{
26-
uint32_t workGroupDims[3];
27-
uint32_t workGroupCount[3];
28-
};
29-
3024
constexpr uint32_t channelCountOverride = 3u;
3125

3226
inline core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader(
@@ -517,23 +511,24 @@ int main()
517511
return dset;
518512
}();
519513

520-
auto fftDispatchInfo_Horizontal = FFTClass::buildParameters(paddedKerDim, FFTClass::Direction::X);
521-
auto fftDispatchInfo_Vertical = FFTClass::buildParameters(paddedKerDim, FFTClass::Direction::Y);
514+
FFTClass::Parameters_t fftPushConstants[2];
515+
FFTClass::DispatchInfo_t fftDispatchInfo[2];
516+
const FFTClass::PaddingType fftPadding[2] = {FFTClass::PaddingType::FILL_WITH_ZERO,FFTClass::PaddingType::FILL_WITH_ZERO};
517+
const auto passes = FFTClass::buildParameters(false,srcNumChannels,kerDim,fftPushConstants,fftDispatchInfo,fftPadding);
518+
assert(passes==2u);
522519

523520
// Ker Image FFT X
524521
{
525522
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver, paddedKerDim.width, "../image_first_fft.comp"));
526523
driver->bindComputePipeline(fftPipeline_ImageInput.get());
527524
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_X.get(), nullptr);
528-
FFTClass::pushConstants(driver, imageFirstFFTPipelineLayout.get(), kerDim, paddedKerDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
529-
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
525+
FFTClass::dispatchHelper(driver, imageFirstFFTPipelineLayout.get(), fftPushConstants[0], fftDispatchInfo[0]);
530526
}
531527

532528
// Ker Image FFT Y
533529
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
534530
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_SSBOInput->getLayout(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
535-
FFTClass::pushConstants(driver, fftPipeline_SSBOInput->getLayout(), paddedKerDim, paddedKerDim, FFTClass::Direction::Y, false, srcNumChannels);
536-
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
531+
FFTClass::dispatchHelper(driver, fftPipeline_SSBOInput->getLayout(), fftPushConstants[1], fftDispatchInfo[1]);
537532

538533
// Ker Normalization
539534
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(fftPipelineLayout_KernelNormalization),
@@ -549,7 +544,7 @@ int main()
549544
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_KernelNormalization.get(), 0u, 1u, &fftDescriptorSet_KernelNormalization.get(), nullptr);
550545
{
551546
NormalizationPushConstants normalizationPC;
552-
normalizationPC.stride = {1u,paddedKerDim.width,paddedKerDim.width*paddedKerDim.height,paddedKerDim.width*paddedKerDim.height}; // TODO: take from the Y FFT pass
547+
normalizationPC.stride = fftPushConstants[1].output_strides;
553548
normalizationPC.bitreverse_shift.x = 32-core::findMSB(paddedKerDim.width);
554549
normalizationPC.bitreverse_shift.y = 32-core::findMSB(paddedKerDim.height);
555550
normalizationPC.bitreverse_shift.z = 0;
@@ -590,31 +585,41 @@ int main()
590585
blitFBO->attach(video::EFAP_COLOR_ATTACHMENT0, std::move(outImgView));
591586

592587

593-
auto fftDispatchInfo_Horizontal = FFTClass::buildParameters(paddedDim, FFTClass::Direction::X);
594-
auto fftDispatchInfo_Vertical = FFTClass::buildParameters(paddedDim, FFTClass::Direction::Y);
588+
FFTClass::Parameters_t fftPushConstants[3];
589+
FFTClass::DispatchInfo_t fftDispatchInfo[3];
590+
const FFTClass::PaddingType fftPadding[2] = {FFTClass::PaddingType::CLAMP_TO_EDGE,FFTClass::PaddingType::CLAMP_TO_EDGE}; // TODO
591+
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding);
592+
{
593+
fftPushConstants[2] = fftPushConstants[0];
594+
{
595+
fftPushConstants[2].input_dimensions.w ^= 0x80000000u;
596+
fftPushConstants[2].input_dimensions.w &= 0xfffffffdu;
597+
}
598+
fftDispatchInfo[2] = fftDispatchInfo[0];
599+
}
600+
assert(passes==2);
601+
595602
while (device->run() && receiver.keepOpen())
596603
{
597604
driver->beginScene(false, false);
598605

599606
// Src Image FFT X
600607
driver->bindComputePipeline(fftPipeline_ImageInput.get());
601608
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Src_FFT_X.get(), nullptr);
602-
FFTClass::pushConstants(driver, imageFirstFFTPipelineLayout.get(), srcDim, paddedDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
603-
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
609+
FFTClass::dispatchHelper(driver, imageFirstFFTPipelineLayout.get(), fftPushConstants[0], fftDispatchInfo[0]);
604610

605611
// Src Image FFT Y + Convolution + Convolved IFFT Y
606612
driver->bindComputePipeline(convolvePipeline.get());
607613
driver->bindDescriptorSets(EPBP_COMPUTE, convolvePipeline->getLayout(), 0u, 1u, &convolveDescriptorSet.get(), nullptr);
608-
FFTClass::pushConstants(driver, convolvePipeline->getLayout(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
609-
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
614+
FFTClass::dispatchHelper(driver, convolvePipeline->getLayout(), fftPushConstants[1], fftDispatchInfo[1]);
610615

611616
// Last FFT Padding and Copy to GPU Image
612617
driver->bindComputePipeline(lastFFTPipeline.get());
613618
driver->bindDescriptorSets(EPBP_COMPUTE, lastFFTPipeline->getLayout(), 0u, 1u, &lastFFTDescriptorSet.get(), nullptr);
614-
FFTClass::pushConstants(driver, lastFFTPipeline->getLayout(), paddedDim, paddedDim, FFTClass::Direction::X, true, srcNumChannels);
615-
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
619+
FFTClass::dispatchHelper(driver, lastFFTPipeline->getLayout(), fftPushConstants[2], fftDispatchInfo[2]);
616620

617-
if(false == savedToFile) {
621+
if(!savedToFile)
622+
{
618623
savedToFile = true;
619624

620625
core::smart_refctd_ptr<ICPUImageView> imageView;

include/nbl/builtin/glsl/ext/FFT/default_compute_fft.comp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
100100
#define _NBL_GLSL_EXT_FFT_MAIN_DEFINED_
101101
void main()
102102
{
103-
const uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
104-
for(uint ch = 0u; ch < numChannels; ++ch)
103+
for(uint ch=0u; ch<=nbl_glsl_ext_FFT_Parameters_t_getMaxChannel(); ++ch)
105104
nbl_glsl_ext_FFT(nbl_glsl_ext_FFT_Parameters_t_getIsInverse(), ch);
106105
}
107106
#endif

include/nbl/builtin/glsl/ext/FFT/parameters.glsl

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,50 +21,49 @@
2121
nbl_glsl_ext_FFT_Parameters_t nbl_glsl_ext_FFT_getParameters();
2222
#endif
2323

24-
uvec3 nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions() {
25-
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
26-
return (params.padded_dimension.xyz);
27-
}
28-
uvec4 nbl_glsl_ext_FFT_Parameters_t_getOutputStrides()
24+
25+
uvec3 nbl_glsl_ext_FFT_Parameters_t_getDimensions()
2926
{
30-
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions();
31-
return uvec4(1u,dimension.x*uvec3(1u,dimension.y*uvec2(1u,dimension.z)));
27+
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
28+
return params.input_dimensions.xyz;
3229
}
3330

34-
uvec3 nbl_glsl_ext_FFT_Parameters_t_getDimensions() {
35-
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
36-
return (params.dimension.xyz);
37-
}
38-
uvec4 nbl_glsl_ext_FFT_Parameters_t_getInputStrides()
31+
bool nbl_glsl_ext_FFT_Parameters_t_getIsInverse()
3932
{
40-
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getDimensions();
41-
return uvec4(1u,dimension.x*uvec3(1u,dimension.y*uvec2(1u,dimension.z)));
33+
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
34+
return bool(params.input_dimensions.w>>31u);
4235
}
43-
uint nbl_glsl_ext_FFT_Parameters_t_getDirection() {
36+
uint nbl_glsl_ext_FFT_Parameters_t_getDirection()
37+
{
4438
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
45-
return (params.dimension.w >> 16) & 0x000000ff;
39+
return (params.input_dimensions.w>>28u)&0x3u;
4640
}
47-
48-
uint nbl_glsl_ext_FFT_Parameters_t_getFFTLength() {
49-
const uint direction = nbl_glsl_ext_FFT_Parameters_t_getDirection();
50-
return nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions()[direction];
41+
uint nbl_glsl_ext_FFT_Parameters_t_getMaxChannel()
42+
{
43+
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
44+
return (params.input_dimensions.w>>26u)&0x3u;
5145
}
5246
uint nbl_glsl_ext_FFT_Parameters_t_getLog2FFTSize()
5347
{
54-
return findMSB(nbl_glsl_ext_FFT_Parameters_t_getFFTLength());
48+
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
49+
return (params.input_dimensions.w>>3u)&0x1fu;
5550
}
56-
57-
bool nbl_glsl_ext_FFT_Parameters_t_getIsInverse() {
51+
uint nbl_glsl_ext_FFT_Parameters_t_getPaddingType()
52+
{
5853
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
59-
return bool((params.dimension.w >> 8) & 0x000000ff);
54+
return params.input_dimensions.w&0x7u;
6055
}
61-
uint nbl_glsl_ext_FFT_Parameters_t_getPaddingType() {
56+
57+
uvec4 nbl_glsl_ext_FFT_Parameters_t_getInputStrides()
58+
{
6259
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
63-
return (params.dimension.w) & 0x000000ff;
60+
return params.input_strides;
6461
}
65-
uint nbl_glsl_ext_FFT_Parameters_t_getNumChannels() {
62+
63+
uvec4 nbl_glsl_ext_FFT_Parameters_t_getOutputStrides()
64+
{
6665
nbl_glsl_ext_FFT_Parameters_t params = nbl_glsl_ext_FFT_getParameters();
67-
return (params.padded_dimension.w);
66+
return params.output_strides;
6867
}
6968

7069
#endif

include/nbl/builtin/glsl/ext/FFT/parameters_struct.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
struct nbl_glsl_ext_FFT_Parameters_t
99
{
10-
uvec4 dimension; // settings packed into the w component : (direction_u8 << 16u) | (isInverse_u8 << 8u) | paddingType_u8;
11-
uvec4 padded_dimension; // num channels in the last channel (again the previous reasoning)
10+
uvec4 input_dimensions; // settings packed into the w component : (isInverse << 31u) | (direction_u3 << 28u) | (maxChannel_u2<<26u) | (fftSizeLog2_u5 << 3u) | paddingType_u3;
11+
uvec4 input_strides;
12+
uvec4 output_strides;
1213
};
1314

1415
#endif

include/nbl/ext/FFT/FFT.h

Lines changed: 46 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ class FFT : public core::TotalInterface
3232
struct Parameters_t alignas(16) : nbl_glsl_ext_FFT_Parameters_t
3333
{
3434
};
35-
36-
enum class Direction : uint8_t
37-
{
38-
X = 0,
39-
Y = 1,
40-
Z = 2,
41-
};
4235

4336
enum class PaddingType : uint8_t
4437
{
@@ -49,46 +42,60 @@ class FFT : public core::TotalInterface
4942

5043
struct DispatchInfo_t
5144
{
52-
uint32_t workGroupDims[3];
5345
uint32_t workGroupCount[3];
5446
};
5547

5648
_NBL_STATIC_INLINE_CONSTEXPR uint32_t DEFAULT_WORK_GROUP_SIZE = 256u;
5749

58-
// returns dispatch size and fills the uniform data
59-
static inline DispatchInfo_t buildParameters(
60-
asset::VkExtent3D const & paddedInputDimensions,
61-
Direction direction)
50+
// returns how many dispatches are necessary for computing the FFT and fills the uniform data
51+
static inline uint32_t buildParameters(bool isInverse, uint32_t numChannels, const asset::VkExtent3D& inputDimensions, Parameters_t* outParams, DispatchInfo_t* outInfos, const PaddingType* paddingType)
6252
{
63-
assert(core::isPoT(paddedInputDimensions.width) && core::isPoT(paddedInputDimensions.height) && core::isPoT(paddedInputDimensions.depth));
64-
DispatchInfo_t ret = {};
53+
uint32_t passesRequired = 0u;
6554

66-
ret.workGroupDims[0] = DEFAULT_WORK_GROUP_SIZE;
67-
ret.workGroupDims[1] = 1;
68-
ret.workGroupDims[2] = 1;
69-
70-
if(direction == Direction::X)
71-
{
72-
ret.workGroupCount[0] = 1;
73-
ret.workGroupCount[1] = paddedInputDimensions.height;
74-
ret.workGroupCount[2] = paddedInputDimensions.depth;
75-
}
76-
else if(direction == Direction::Y) {
77-
ret.workGroupCount[0] = paddedInputDimensions.width;
78-
ret.workGroupCount[1] = 1;
79-
ret.workGroupCount[2] = paddedInputDimensions.depth;
80-
}
81-
else if(direction == Direction::Z)
55+
if (numChannels)
8256
{
83-
ret.workGroupCount[0] = paddedInputDimensions.width;
84-
ret.workGroupCount[1] = paddedInputDimensions.height;
85-
ret.workGroupCount[2] = 1;
57+
const auto paddedInputDimensions = padDimensionToNextPOT(inputDimensions);
58+
for (uint32_t i=0u; i<3u; i++)
59+
if ((&inputDimensions.width)[i]>1u)
60+
{
61+
// TODO: rework
62+
auto& dispatch = outInfos[passesRequired];
63+
dispatch.workGroupCount[0] = paddedInputDimensions.width;
64+
dispatch.workGroupCount[1] = paddedInputDimensions.height;
65+
dispatch.workGroupCount[2] = paddedInputDimensions.depth;
66+
dispatch.workGroupCount[i] = 1u;
67+
68+
auto& params = outParams[passesRequired];
69+
params.input_dimensions.x = inputDimensions.width;
70+
params.input_dimensions.y = inputDimensions.height;
71+
params.input_dimensions.z = inputDimensions.depth;
72+
{
73+
// round up to workgroup size if too small
74+
const uint32_t fftSize = core::max(DEFAULT_WORK_GROUP_SIZE,(&paddedInputDimensions.width)[i]);
75+
76+
params.input_dimensions.w = (isInverse ? 0x80000000u:0x0u)|
77+
(i<<28u)| // direction
78+
((numChannels-1u)<<26u)| // max channel
79+
(core::findMSB(fftSize)<<3u)| // log2(fftSize)
80+
uint32_t(paddingType[i]);
81+
}
82+
params.input_strides.x = 1u;
83+
params.input_strides.y = paddedInputDimensions.width;
84+
params.input_strides.z = params.input_strides.y*paddedInputDimensions.height;
85+
params.input_strides.w = params.input_strides.z*paddedInputDimensions.depth;
86+
params.output_strides = params.input_strides;
87+
88+
passesRequired++;
89+
}
8690
}
8791

88-
return ret;
92+
if (passesRequired)
93+
outParams[passesRequired-1u].output_strides = outParams[0].input_strides;
94+
95+
return passesRequired;
8996
}
9097

91-
98+
// TODO: remove?
9299
static inline asset::VkExtent3D padDimensionToNextPOT(asset::VkExtent3D dimension, asset::VkExtent3D const & minimum_dimension = asset::VkExtent3D{ 1, 1, 1 })
93100
{
94101
if(dimension.width < minimum_dimension.width)
@@ -114,7 +121,7 @@ class FFT : public core::TotalInterface
114121
//
115122
static core::smart_refctd_ptr<video::IGPUPipelineLayout> getDefaultPipelineLayout(video::IVideoDriver* driver);
116123

117-
//
124+
// TODO: rework?
118125
static inline size_t getOutputBufferSize(asset::VkExtent3D const & paddedInputDimensions, uint32_t numChannels)
119126
{
120127
assert(core::isPoT(paddedInputDimensions.width) && core::isPoT(paddedInputDimensions.height) && core::isPoT(paddedInputDimensions.depth));
@@ -162,45 +169,18 @@ class FFT : public core::TotalInterface
162169

163170
static inline void dispatchHelper(
164171
video::IVideoDriver* driver,
172+
const video::IGPUPipelineLayout* pipelineLayout,
173+
const Parameters_t& params,
165174
const DispatchInfo_t& dispatchInfo,
166175
bool issueDefaultBarrier=true)
167176
{
177+
driver->pushConstants(pipelineLayout,video::IGPUSpecializedShader::ESS_COMPUTE,0u,sizeof(Parameters_t),&params);
168178
driver->dispatch(dispatchInfo.workGroupCount[0], dispatchInfo.workGroupCount[1], dispatchInfo.workGroupCount[2]);
169179

170180
if (issueDefaultBarrier)
171181
defaultBarrier();
172182
}
173183

174-
static inline void pushConstants(
175-
video::IVideoDriver* driver,
176-
const video::IGPUPipelineLayout * pipelineLayout,
177-
asset::VkExtent3D const & inputDimension,
178-
asset::VkExtent3D const & paddedInputDimension,
179-
Direction direction,
180-
bool isInverse,
181-
uint32_t numChannels,
182-
PaddingType paddingType = PaddingType::CLAMP_TO_EDGE)
183-
{
184-
185-
uint8_t isInverse_u8 = isInverse;
186-
uint8_t direction_u8 = static_cast<uint8_t>(direction);
187-
uint8_t paddingType_u8 = static_cast<uint8_t>(paddingType);
188-
189-
uint32_t packed = (direction_u8 << 16u) | (isInverse_u8 << 8u) | paddingType_u8;
190-
191-
Parameters_t params = {};
192-
params.dimension.x = inputDimension.width;
193-
params.dimension.y = inputDimension.height;
194-
params.dimension.z = inputDimension.depth;
195-
params.dimension.w = packed;
196-
params.padded_dimension.x = paddedInputDimension.width;
197-
params.padded_dimension.y = paddedInputDimension.height;
198-
params.padded_dimension.z = paddedInputDimension.depth;
199-
params.padded_dimension.w = numChannels;
200-
201-
driver->pushConstants(pipelineLayout, nbl::video::IGPUSpecializedShader::ESS_COMPUTE, 0u, sizeof(Parameters_t), &params);
202-
}
203-
204184
static void defaultBarrier();
205185

206186
private:

0 commit comments

Comments
 (0)