@@ -21,12 +21,6 @@ using namespace nbl::video;
21
21
#include " nbl/core/math/intutil.h"
22
22
#include " nbl/core/math/glslFunctions.h"
23
23
24
- struct DispatchInfo_t
25
- {
26
- uint32_t workGroupDims[3 ];
27
- uint32_t workGroupCount[3 ];
28
- };
29
-
30
24
constexpr uint32_t channelCountOverride = 3u ;
31
25
32
26
inline core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader (
@@ -517,23 +511,24 @@ int main()
517
511
return dset;
518
512
}();
519
513
520
- auto fftDispatchInfo_Horizontal = FFTClass::buildParameters (paddedKerDim, FFTClass::Direction::X);
521
- auto fftDispatchInfo_Vertical = FFTClass::buildParameters (paddedKerDim, FFTClass::Direction::Y);
514
+ FFTClass::Parameters_t fftPushConstants[2 ];
515
+ FFTClass::DispatchInfo_t fftDispatchInfo[2 ];
516
+ const FFTClass::PaddingType fftPadding[2 ] = {FFTClass::PaddingType::FILL_WITH_ZERO,FFTClass::PaddingType::FILL_WITH_ZERO};
517
+ const auto passes = FFTClass::buildParameters (false ,srcNumChannels,kerDim,fftPushConstants,fftDispatchInfo,fftPadding);
518
+ assert (passes==2u );
522
519
523
520
// Ker Image FFT X
524
521
{
525
522
auto fftPipeline_ImageInput = driver->createGPUComputePipeline (nullptr ,core::smart_refctd_ptr (imageFirstFFTPipelineLayout),createShader (driver, paddedKerDim.width , " ../image_first_fft.comp" ));
526
523
driver->bindComputePipeline (fftPipeline_ImageInput.get ());
527
524
driver->bindDescriptorSets (EPBP_COMPUTE, imageFirstFFTPipelineLayout.get (), 0u , 1u , &fftDescriptorSet_Ker_FFT_X.get (), nullptr );
528
- FFTClass::pushConstants (driver, imageFirstFFTPipelineLayout.get (), kerDim, paddedKerDim, FFTClass::Direction::X, false , srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
529
- FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
525
+ FFTClass::dispatchHelper (driver, imageFirstFFTPipelineLayout.get (), fftPushConstants[0 ], fftDispatchInfo[0 ]);
530
526
}
531
527
532
528
// Ker Image FFT Y
533
529
driver->bindComputePipeline (fftPipeline_SSBOInput.get ());
534
530
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipeline_SSBOInput->getLayout (), 0u , 1u , &fftDescriptorSet_Ker_FFT_Y.get (), nullptr );
535
- FFTClass::pushConstants (driver, fftPipeline_SSBOInput->getLayout (), paddedKerDim, paddedKerDim, FFTClass::Direction::Y, false , srcNumChannels);
536
- FFTClass::dispatchHelper (driver, fftDispatchInfo_Vertical);
531
+ FFTClass::dispatchHelper (driver, fftPipeline_SSBOInput->getLayout (), fftPushConstants[1 ], fftDispatchInfo[1 ]);
537
532
538
533
// Ker Normalization
539
534
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline (nullptr , core::smart_refctd_ptr (fftPipelineLayout_KernelNormalization),
@@ -549,7 +544,7 @@ int main()
549
544
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_KernelNormalization.get (), 0u , 1u , &fftDescriptorSet_KernelNormalization.get (), nullptr );
550
545
{
551
546
NormalizationPushConstants normalizationPC;
552
- normalizationPC.stride = { 1u ,paddedKerDim. width ,paddedKerDim. width *paddedKerDim. height ,paddedKerDim. width *paddedKerDim. height }; // TODO: take from the Y FFT pass
547
+ normalizationPC.stride = fftPushConstants[ 1 ]. output_strides ;
553
548
normalizationPC.bitreverse_shift .x = 32 -core::findMSB (paddedKerDim.width );
554
549
normalizationPC.bitreverse_shift .y = 32 -core::findMSB (paddedKerDim.height );
555
550
normalizationPC.bitreverse_shift .z = 0 ;
@@ -590,31 +585,41 @@ int main()
590
585
blitFBO->attach (video::EFAP_COLOR_ATTACHMENT0, std::move (outImgView));
591
586
592
587
593
- auto fftDispatchInfo_Horizontal = FFTClass::buildParameters (paddedDim, FFTClass::Direction::X);
594
- auto fftDispatchInfo_Vertical = FFTClass::buildParameters (paddedDim, FFTClass::Direction::Y);
588
// Push constants and dispatch info for the three source-image dispatches issued in the render loop.
// Slots 0 and 1 are passed to buildParameters; slot 2 is derived by hand from slot 0 below.
// NOTE(review): presumably buildParameters fills one entry per returned pass — confirm against FFTClass.
+ FFTClass::Parameters_t fftPushConstants[3 ];
589
+ FFTClass::DispatchInfo_t fftDispatchInfo[3 ];
590
+ const FFTClass::PaddingType fftPadding[2 ] = {FFTClass::PaddingType::CLAMP_TO_EDGE,FFTClass::PaddingType::CLAMP_TO_EDGE}; // TODO
591
+ const auto passes = FFTClass::buildParameters (false ,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding);
592
+ {
593
// Slot 2 starts as a copy of slot 0, then has two bits of input_dimensions.w adjusted.
+ fftPushConstants[2 ] = fftPushConstants[0 ];
594
+ {
595
// Toggle bit 31 of input_dimensions.w.
// NOTE(review): presumably the forward/inverse flag, since slot 2 feeds the last FFT dispatch — confirm against the Parameters_t packing.
+ fftPushConstants[2 ].input_dimensions .w ^= 0x80000000u ;
596
// Clear bit 1 of input_dimensions.w.
// NOTE(review): presumably part of the packed padding-type field — confirm against the Parameters_t packing.
+ fftPushConstants[2 ].input_dimensions .w &= 0xfffffffdu ;
597
+ }
598
// Slot 2 reuses the dispatch info of slot 0 unchanged.
+ fftDispatchInfo[2 ] = fftDispatchInfo[0 ];
599
+ }
600
// Exactly two passes are expected from buildParameters; this is only checked when assertions are enabled.
+ assert (passes==2 );
601
+
595
602
while (device->run () && receiver.keepOpen ())
596
603
{
597
604
driver->beginScene (false , false );
598
605
599
606
// Src Image FFT X
600
607
driver->bindComputePipeline (fftPipeline_ImageInput.get ());
601
608
driver->bindDescriptorSets (EPBP_COMPUTE, imageFirstFFTPipelineLayout.get (), 0u , 1u , &fftDescriptorSet_Src_FFT_X.get (), nullptr );
602
- FFTClass::pushConstants (driver, imageFirstFFTPipelineLayout.get (), srcDim, paddedDim, FFTClass::Direction::X, false , srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
603
- FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
609
+ FFTClass::dispatchHelper (driver, imageFirstFFTPipelineLayout.get (), fftPushConstants[0 ], fftDispatchInfo[0 ]);
604
610
605
611
// Src Image FFT Y + Convolution + Convolved IFFT Y
606
612
driver->bindComputePipeline (convolvePipeline.get ());
607
613
driver->bindDescriptorSets (EPBP_COMPUTE, convolvePipeline->getLayout (), 0u , 1u , &convolveDescriptorSet.get (), nullptr );
608
- FFTClass::pushConstants (driver, convolvePipeline->getLayout (), paddedDim, paddedDim, FFTClass::Direction::Y, false , srcNumChannels);
609
- FFTClass::dispatchHelper (driver, fftDispatchInfo_Vertical);
614
+ FFTClass::dispatchHelper (driver, convolvePipeline->getLayout (), fftPushConstants[1 ], fftDispatchInfo[1 ]);
610
615
611
616
// Last FFT Padding and Copy to GPU Image
612
617
driver->bindComputePipeline (lastFFTPipeline.get ());
613
618
driver->bindDescriptorSets (EPBP_COMPUTE, lastFFTPipeline->getLayout (), 0u , 1u , &lastFFTDescriptorSet.get (), nullptr );
614
- FFTClass::pushConstants (driver, lastFFTPipeline->getLayout (), paddedDim, paddedDim, FFTClass::Direction::X, true , srcNumChannels);
615
- FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
619
+ FFTClass::dispatchHelper (driver, lastFFTPipeline->getLayout (), fftPushConstants[2 ], fftDispatchInfo[2 ]);
616
620
617
- if (false == savedToFile) {
621
+ if (!savedToFile)
622
+ {
618
623
savedToFile = true ;
619
624
620
625
core::smart_refctd_ptr<ICPUImageView> imageView;
0 commit comments