Skip to content

Commit 85c79fb

Browse files
axis swaps work
1 parent 197da09 commit 85c79fb

File tree

2 files changed

+63
-61
lines changed

2 files changed

+63
-61
lines changed

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ constexpr uint32_t channelCountOverride = 3u;
2424

2525
inline core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader(
2626
video::IVideoDriver* driver,
27-
const FFTClass* fft,
27+
const uint32_t maxFFTlen,
28+
const bool useHalfStorage,
2829
const char* includeMainName,
2930
float kernelScale = 1.f)
3031
{
@@ -48,8 +49,8 @@ R"===(#version 430 core
4849
snprintf(
4950
reinterpret_cast<char*>(shader->getPointer()),shader->getSize(), sourceFmt,
5051
DEFAULT_WORK_GROUP_SIZE,
51-
fft->getMaxFFTLength(),
52-
fft->usesHalfFloatStorage() ? 1u:0u,
52+
maxFFTlen,
53+
useHalfStorage ? 1u:0u,
5354
kernelScale,
5455
includeMainName
5556
);
@@ -456,18 +457,26 @@ int main()
456457
};
457458
for (uint32_t i=0u; i<channelCountOverride; i++)
458459
kernelNormalizedSpectrums[i] = createKernelSpectrum();
460+
459461

460-
// Ker FFT X
461-
auto fftDescriptorSet_Ker_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(imageFirstFFTPipelineLayout->getDescriptorSetLayout(0u)));
462-
updateDescriptorSet(fftDescriptorSet_Ker_FFT_X.get(), kerImageView, ISampler::ETC_CLAMP_TO_BORDER, fftOutputBuffer_0);
462+
FFTClass::Parameters_t fftPushConstants[2];
463+
FFTClass::DispatchInfo_t fftDispatchInfo[2];
464+
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_CLAMP_TO_BORDER,ISampler::ETC_CLAMP_TO_BORDER};
465+
const auto passes = FFTClass::buildParameters(false,srcNumChannels,kerDim,fftPushConstants,fftDispatchInfo,fftPadding);
466+
assert(passes==2u);
467+
// last FFT pipeline
468+
core::smart_refctd_ptr<IGPUComputePipeline> fftPipeline_SSBOInput(core::make_smart_refctd_ptr<FFTClass>(driver,0x1u<<fftPushConstants[1].getLog2FFTSize(),useHalfFloats)->getDefaultPipeline());
463469

464-
// Ker FFT Y
465-
auto fft_y = core::make_smart_refctd_ptr<FFTClass>(driver,kerDim.height,useHalfFloats);
466-
auto fftPipeline_SSBOInput = fft_y->getDefaultPipeline();
467-
auto fftDescriptorSet_Ker_FFT_Y = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipeline_SSBOInput->getLayout()->getDescriptorSetLayout(0u)));
468-
FFTClass::updateDescriptorSet(driver, fftDescriptorSet_Ker_FFT_Y.get(), fftOutputBuffer_0, fftOutputBuffer_1);
470+
// descriptor sets
471+
core::smart_refctd_ptr<IGPUDescriptorSet> fftDescriptorSet_Ker_FFT[2] =
472+
{
473+
driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(imageFirstFFTPipelineLayout->getDescriptorSetLayout(0u))),
474+
driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipeline_SSBOInput->getLayout()->getDescriptorSetLayout(0u)))
475+
};
476+
updateDescriptorSet(fftDescriptorSet_Ker_FFT[0].get(), kerImageView, ISampler::ETC_CLAMP_TO_BORDER, fftOutputBuffer_0);
477+
FFTClass::updateDescriptorSet(driver,fftDescriptorSet_Ker_FFT[1].get(), fftOutputBuffer_0, fftOutputBuffer_1);
469478

470-
// Normalization of FFT Y result
479+
// Normalization of FFT spectrum
471480
struct NormalizationPushConstants
472481
{
473482
ext::FFT::uvec4 stride;
@@ -540,28 +549,21 @@ int main()
540549
return dset;
541550
}();
542551

543-
FFTClass::Parameters_t fftPushConstants[2];
544-
FFTClass::DispatchInfo_t fftDispatchInfo[2];
545-
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_CLAMP_TO_BORDER,ISampler::ETC_CLAMP_TO_BORDER};
546-
const auto passes = FFTClass::buildParameters(false,srcNumChannels,kerDim,fftPushConstants,fftDispatchInfo,fftPadding);
547-
assert(passes==2u);
548-
549-
// Ker Image FFT X
550-
auto fft_x = core::make_smart_refctd_ptr<FFTClass>(driver, kerDim.height, useHalfFloats);
552+
// Ker Image First Axis FFT
551553
{
552-
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver,fft_x.get(),"../image_first_fft.comp",bloomScale));
554+
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver,0x1u<<fftPushConstants[0].getLog2FFTSize(),useHalfFloats,"../image_first_fft.comp",bloomScale));
553555
driver->bindComputePipeline(fftPipeline_ImageInput.get());
554-
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_X.get(), nullptr);
556+
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT[0].get(), nullptr);
555557
FFTClass::dispatchHelper(driver, imageFirstFFTPipelineLayout.get(), fftPushConstants[0], fftDispatchInfo[0]);
556558
}
557559

558-
// Ker Image FFT Y
559-
driver->bindComputePipeline(fftPipeline_SSBOInput);
560-
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_SSBOInput->getLayout(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
560+
// Ker Image Last Axis FFT
561+
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
562+
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_SSBOInput->getLayout(), 0u, 1u, &fftDescriptorSet_Ker_FFT[1].get(), nullptr);
561563
FFTClass::dispatchHelper(driver, fftPipeline_SSBOInput->getLayout(), fftPushConstants[1], fftDispatchInfo[1]);
562564

563565
// Ker Normalization
564-
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(fftPipelineLayout_KernelNormalization),createShader(driver,fft_x.get(),"../normalization.comp"));
566+
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(fftPipelineLayout_KernelNormalization),createShader(driver,0xdeadbeefu,useHalfFloats,"../normalization.comp"));
565567
driver->bindComputePipeline(fftPipeline_KernelNormalization.get());
566568
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_KernelNormalization.get(), 0u, 1u, &fftDescriptorSet_KernelNormalization.get(), nullptr);
567569
{
@@ -580,22 +582,35 @@ int main()
580582
}
581583
}
582584

585+
FFTClass::Parameters_t fftPushConstants[3];
586+
FFTClass::DispatchInfo_t fftDispatchInfo[3];
587+
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_MIRROR,ISampler::ETC_MIRROR};
588+
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding,marginSrcDim);
589+
{
590+
fftPushConstants[1].output_strides = fftPushConstants[1].input_strides; // override for less work and storage (dont need to store the extra padding of the last axis after iFFT)
591+
fftPushConstants[2].input_dimensions = fftPushConstants[1].input_dimensions;
592+
{
593+
fftPushConstants[2].input_dimensions.w = fftPushConstants[0].input_dimensions.w^0x80000000u;
594+
fftPushConstants[2].input_strides = fftPushConstants[1].output_strides;
595+
fftPushConstants[2].output_strides = fftPushConstants[0].input_strides;
596+
}
597+
fftDispatchInfo[2] = fftDispatchInfo[0];
598+
}
599+
assert(passes==2);
583600
// pipelines
584-
auto fft_x = core::make_smart_refctd_ptr<FFTClass>(driver,marginSrcDim.width,useHalfFloats);
585-
auto fft_y = core::make_smart_refctd_ptr<FFTClass>(driver,marginSrcDim.height,useHalfFloats);
586-
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver,fft_x.get(), "../image_first_fft.comp"));
587-
auto convolvePipeline = driver->createGPUComputePipeline(nullptr, std::move(convolvePipelineLayout), createShader(driver,fft_y.get(), "../fft_convolve_ifft.comp"));
588-
auto lastFFTPipeline = driver->createGPUComputePipeline(nullptr, std::move(lastFFTPipelineLayout), createShader(driver,fft_x.get(), "../last_fft.comp"));
601+
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver,0x1u<<fftPushConstants[0].getLog2FFTSize(),useHalfFloats,"../image_first_fft.comp"));
602+
auto convolvePipeline = driver->createGPUComputePipeline(nullptr, std::move(convolvePipelineLayout), createShader(driver,0x1u<<fftPushConstants[1].getLog2FFTSize(),useHalfFloats, "../fft_convolve_ifft.comp"));
603+
auto lastFFTPipeline = driver->createGPUComputePipeline(nullptr, std::move(lastFFTPipelineLayout), createShader(driver,0x1u<<fftPushConstants[0].getLog2FFTSize(),useHalfFloats,"../last_fft.comp"));
589604

590-
// Src FFT X
591-
auto fftDescriptorSet_Src_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(imageFirstFFTPipelineLayout->getDescriptorSetLayout(0u)));
592-
updateDescriptorSet(fftDescriptorSet_Src_FFT_X.get(), srcImageView, ISampler::ETC_MIRROR, fftOutputBuffer_0);
605+
// Src First Axis FFT
606+
auto fftDescriptorSet_Src_FirstFFT = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(imageFirstFFTPipelineLayout->getDescriptorSetLayout(0u)));
607+
updateDescriptorSet(fftDescriptorSet_Src_FirstFFT.get(), srcImageView, ISampler::ETC_MIRROR, fftOutputBuffer_0);
593608

594609
// Convolution
595610
auto convolveDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(convolvePipeline->getLayout()->getDescriptorSetLayout(0u)));
596611
updateDescriptorSet_Convolution(driver, convolveDescriptorSet.get(), fftOutputBuffer_0, fftOutputBuffer_1, kernelNormalizedSpectrums);
597612

598-
// Last IFFTX
613+
// Last Axis IFFT
599614
auto lastFFTDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(lastFFTPipeline->getLayout()->getDescriptorSetLayout(0u)));
600615
updateDescriptorSet_LastFFT(driver, lastFFTDescriptorSet.get(), fftOutputBuffer_1, outImgView);
601616

@@ -608,33 +623,16 @@ int main()
608623
auto blitFBO = driver->addFrameBuffer();
609624
blitFBO->attach(video::EFAP_COLOR_ATTACHMENT0, std::move(outImgView));
610625

611-
612-
FFTClass::Parameters_t fftPushConstants[3];
613-
FFTClass::DispatchInfo_t fftDispatchInfo[3];
614-
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_MIRROR,ISampler::ETC_MIRROR};
615-
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding,marginSrcDim);
616-
{
617-
fftPushConstants[1].output_strides = fftPushConstants[1].input_strides; // override for less work and storage (dont need to store the extra Y-slices after iFFT)
618-
fftPushConstants[2].input_dimensions = fftPushConstants[1].input_dimensions;
619-
{
620-
fftPushConstants[2].input_dimensions.w = fftPushConstants[0].input_dimensions.w^0x80000000u;
621-
fftPushConstants[2].input_strides = fftPushConstants[1].output_strides;
622-
fftPushConstants[2].output_strides = fftPushConstants[0].input_strides;
623-
}
624-
fftDispatchInfo[2] = fftDispatchInfo[0];
625-
}
626-
assert(passes==2);
627-
628626
while (device->run() && receiver.keepOpen())
629627
{
630628
driver->beginScene(false, false);
631629

632-
// Src Image FFT X
630+
// Src Image First Axis FFT
633631
driver->bindComputePipeline(fftPipeline_ImageInput.get());
634-
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Src_FFT_X.get(), nullptr);
632+
driver->bindDescriptorSets(EPBP_COMPUTE, imageFirstFFTPipelineLayout.get(), 0u, 1u, &fftDescriptorSet_Src_FirstFFT.get(), nullptr);
635633
FFTClass::dispatchHelper(driver, imageFirstFFTPipelineLayout.get(), fftPushConstants[0], fftDispatchInfo[0]);
636634

637-
// Src Image FFT Y + Convolution + Convolved IFFT Y
635+
// Src Image Last Axis FFT + Convolution + Convolved Last Axis IFFT Y
638636
driver->bindComputePipeline(convolvePipeline.get());
639637
driver->bindDescriptorSets(EPBP_COMPUTE, convolvePipeline->getLayout(), 0u, 1u, &convolveDescriptorSet.get(), nullptr);
640638
{

include/nbl/ext/FFT/FFT.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class FFT final : public core::IReferenceCounted
3232
public:
3333
struct Parameters_t alignas(16) : nbl_glsl_ext_FFT_Parameters_t
3434
{
35+
inline uint getLog2FFTSize()
36+
{
37+
return (input_dimensions.w>>3u)&0x1fu;
38+
}
3539
};
3640

3741
struct DispatchInfo_t
@@ -53,7 +57,7 @@ class FFT final : public core::IReferenceCounted
5357

5458
const auto paddedInputDimensions = padDimensions(extraPaddedInputDimensions);
5559

56-
using SizeAxisPair = std::tuple<uint32_t,uint8_t,uint8_t>;
60+
using SizeAxisPair = std::tuple<float,uint8_t,uint8_t>;
5761
std::array<SizeAxisPair,3u> passes;
5862
if (numChannels)
5963
{
@@ -62,9 +66,9 @@ class FFT final : public core::IReferenceCounted
6266
auto dim = (&paddedInputDimensions.width)[i];
6367
if (dim<2u)
6468
continue;
65-
passes[passesRequired++] = {dim,i,paddingType[i]};
69+
passes[passesRequired++] = {float(dim)/float((&inputDimensions.width)[i]),i,paddingType[i]};
6670
}
67-
std::sort(passes.begin(),passes.begin()+passesRequired,[](const auto& lhs, const auto& rhs)->bool{return std::get<0u>(lhs)>std::get<0u>(rhs);});
71+
std::sort(passes.begin(),passes.begin()+passesRequired);
6872
}
6973

7074
auto computeOutputStride = [](const uvec3& output_dimensions, const auto axis, const auto nextAxis) -> uvec4
@@ -99,17 +103,17 @@ class FFT final : public core::IReferenceCounted
99103
params.input_dimensions.y = output_dimensions.y;
100104
params.input_dimensions.z = output_dimensions.z;
101105

102-
const auto paddedAxisLen = std::get<0u>(passes[i]);
106+
const auto passAxis = std::get<1u>(passes[i]);
107+
const auto paddedAxisLen = (&paddedInputDimensions.width)[passAxis];
103108
{
104109
assert(paddingType[i]<=asset::ISampler::ETC_MIRROR);
105110
params.input_dimensions.w = (isInverse ? 0x80000000u:0x0u)|
106-
(i<<28u)| // direction
111+
(passAxis<<28u)| // direction
107112
((numChannels-1u)<<26u)| // max channel
108113
(core::findMSB(paddedAxisLen)<<3u)| // log2(fftSize)
109114
uint32_t(std::get<2u>(passes[i]));
110115
}
111116

112-
const auto passAxis = std::get<1u>(passes[i]);
113117
(&output_dimensions.x)[passAxis] = paddedAxisLen;
114118
if (i)
115119
params.input_strides = outParams[i-1u].output_strides;

0 commit comments

Comments
 (0)