Skip to content

Commit 2df8fca

Browse files
more cleanup
1 parent 35f0046 commit 2df8fca

File tree

3 files changed

+30
-31
lines changed

3 files changed

+30
-31
lines changed

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ int main()
290290
auto srcImageBundle = am->getAsset("../../media/colorexr.exr", lp);
291291
auto kerImageBundle = am->getAsset("../../media/kernels/physical_flare_512.exr", lp);
292292

293+
// get GPU image views
293294
smart_refctd_ptr<IGPUImageView> srcImageView;
294295
{
295296
auto srcGpuImages = driver->getGPUObjectsFromAssets<ICPUImage>(srcImageBundle.getContents());
@@ -323,6 +324,7 @@ int main()
323324
kerImageView = driver->createGPUImageView(std::move(kerImgViewInfo));
324325
}
325326

327+
// agree on formats
326328
using FFTClass = ext::FFT::FFT;
327329

328330
const E_FORMAT srcFormat = srcImageView->getCreationParameters().format;
@@ -347,19 +349,14 @@ int main()
347349
dstImgViewInfo.image = outImg;
348350
outImgView = driver->createGPUImageView(IGPUImageView::SCreationParams(dstImgViewInfo));
349351
}
350-
auto fftGPUSpecializedShader_ImageInput = FFTClass::createShader(driver, FFTClass::DataType::TEXTURE2D, srcDim.width);
351-
352-
auto fftPipelineLayout_ImageInput = FFTClass::getDefaultPipelineLayout(driver, FFTClass::DataType::TEXTURE2D);
353-
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(fftPipelineLayout_ImageInput), std::move(fftGPUSpecializedShader_ImageInput));
352+
353+
// input pipeline
354+
auto fftPipeline_ImageInput = FFTClass::getDefaultPipeline(driver,FFTClass::DataType::TEXTURE2D,srcDim.width);
354355

355356
const VkExtent3D paddedDim = FFTClass::padDimensionToNextPOT(srcDim);
356-
auto convolveShader = createShader_Convolution(driver, am, paddedDim.height);
357357
auto convolvePipelineLayout = getPipelineLayout_Convolution(driver);
358-
auto convolvePipeline = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(convolvePipelineLayout), std::move(convolveShader));
359-
360-
auto lastFFTShader = createShader_LastFFT(driver, am, paddedDim.width);
361-
auto lastFFTPipelineLayout = getPipelineLayout_LastFFT(driver);
362-
auto lastFFTPipeline = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(lastFFTPipelineLayout), std::move(lastFFTShader));
358+
auto convolvePipeline = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(convolvePipelineLayout), createShader_Convolution(driver, am, paddedDim.height));
359+
auto lastFFTPipeline = driver->createGPUComputePipeline(nullptr, getPipelineLayout_LastFFT(driver), createShader_LastFFT(driver,am,paddedDim.width));
363360

364361
// Allocate Output Buffer
365362
auto fftOutputBuffer_0 = driver->createDeviceLocalGPUBufferOnDedMem(FFTClass::getOutputBufferSize(paddedDim, srcNumChannels)); // result of: srcFFTX and kerFFTX and Convolution and IFFTY
@@ -398,12 +395,12 @@ int main()
398395
kernelNormalizedSpectrums[i] = createKernelSpectrum();
399396

400397
// Ker FFT X
401-
auto fftDescriptorSet_Ker_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipelineLayout_ImageInput->getDescriptorSetLayout(0u)));
398+
auto fftDescriptorSet_Ker_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipeline_ImageInput->getLayout()->getDescriptorSetLayout(0u)));
402399
FFTClass::updateDescriptorSet(driver, fftDescriptorSet_Ker_FFT_X.get(), kerImageView, fftOutputBuffer_0, ISampler::ETC_CLAMP_TO_BORDER);
403400

404401
// Ker FFT Y
405-
auto fftPipelineLayout_SSBOInput = FFTClass::getDefaultPipelineLayout(driver, FFTClass::DataType::SSBO);
406-
auto fftDescriptorSet_Ker_FFT_Y = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipelineLayout_SSBOInput->getDescriptorSetLayout(0u)));
402+
auto fftPipeline_SSBOInput = FFTClass::getDefaultPipeline(driver,FFTClass::DataType::SSBO,kerDim.height);
403+
auto fftDescriptorSet_Ker_FFT_Y = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipeline_SSBOInput->getLayout()->getDescriptorSetLayout(0u)));
407404
FFTClass::updateDescriptorSet(driver, fftDescriptorSet_Ker_FFT_Y.get(), fftOutputBuffer_0, fftOutputBuffer_1);
408405

409406
// Normalization of FFT Y result
@@ -475,15 +472,14 @@ int main()
475472

476473
// Ker Image FFT X
477474
driver->bindComputePipeline(fftPipeline_ImageInput.get());
478-
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_ImageInput.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_X.get(), nullptr);
479-
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), kerDim, paddedKerDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
475+
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_ImageInput->getLayout(), 0u, 1u, &fftDescriptorSet_Ker_FFT_X.get(), nullptr);
476+
FFTClass::pushConstants(driver, fftPipeline_ImageInput->getLayout(), kerDim, paddedKerDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
480477
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
481478

482479
// Ker Image FFT Y
483-
auto fftPipeline_SSBOInput = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(fftPipelineLayout_SSBOInput), FFTClass::createShader(driver,FFTClass::DataType::SSBO,kerDim.height));
484480
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
485-
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
486-
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedKerDim, paddedKerDim, FFTClass::Direction::Y, false, srcNumChannels);
481+
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_SSBOInput->getLayout(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
482+
FFTClass::pushConstants(driver, fftPipeline_SSBOInput->getLayout(), paddedKerDim, paddedKerDim, FFTClass::Direction::Y, false, srcNumChannels);
487483
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
488484

489485
// Ker Normalization
@@ -507,15 +503,15 @@ int main()
507503
}
508504

509505
// Src FFT X
510-
auto fftDescriptorSet_Src_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipelineLayout_ImageInput->getDescriptorSetLayout(0u)));
506+
auto fftDescriptorSet_Src_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipeline_ImageInput->getLayout()->getDescriptorSetLayout(0u)));
511507
FFTClass::updateDescriptorSet(driver, fftDescriptorSet_Src_FFT_X.get(), srcImageView, fftOutputBuffer_0, ISampler::ETC_MIRROR);
512508

513509
// Convolution
514-
auto convolveDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(convolvePipelineLayout->getDescriptorSetLayout(0u)));
510+
auto convolveDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(convolvePipeline->getLayout()->getDescriptorSetLayout(0u)));
515511
updateDescriptorSet_Convolution(driver, convolveDescriptorSet.get(), fftOutputBuffer_0, kernelNormalizedSpectrums);
516512

517513
// Last IFFTX
518-
auto lastFFTDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(lastFFTPipelineLayout->getDescriptorSetLayout(0u)));
514+
auto lastFFTDescriptorSet = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(lastFFTPipeline->getLayout()->getDescriptorSetLayout(0u)));
519515
updateDescriptorSet_LastFFT(driver, lastFFTDescriptorSet.get(), fftOutputBuffer_0, outImgView);
520516

521517
uint32_t outBufferIx = 0u;
@@ -536,20 +532,20 @@ int main()
536532

537533
// Src Image FFT X
538534
driver->bindComputePipeline(fftPipeline_ImageInput.get());
539-
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_ImageInput.get(), 0u, 1u, &fftDescriptorSet_Src_FFT_X.get(), nullptr);
540-
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), srcDim, paddedDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
535+
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipeline_ImageInput->getLayout(), 0u, 1u, &fftDescriptorSet_Src_FFT_X.get(), nullptr);
536+
FFTClass::pushConstants(driver, fftPipeline_ImageInput->getLayout(), srcDim, paddedDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
541537
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
542538

543539
// Src Image FFT Y + Convolution + Convolved IFFT Y
544540
driver->bindComputePipeline(convolvePipeline.get());
545-
driver->bindDescriptorSets(EPBP_COMPUTE, convolvePipelineLayout.get(), 0u, 1u, &convolveDescriptorSet.get(), nullptr);
546-
FFTClass::pushConstants(driver, convolvePipelineLayout.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
541+
driver->bindDescriptorSets(EPBP_COMPUTE, convolvePipeline->getLayout(), 0u, 1u, &convolveDescriptorSet.get(), nullptr);
542+
FFTClass::pushConstants(driver, convolvePipeline->getLayout(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
547543
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
548544

549545
// Last FFT Padding and Copy to GPU Image
550546
driver->bindComputePipeline(lastFFTPipeline.get());
551-
driver->bindDescriptorSets(EPBP_COMPUTE, lastFFTPipelineLayout.get(), 0u, 1u, &lastFFTDescriptorSet.get(), nullptr);
552-
FFTClass::pushConstants(driver, lastFFTPipelineLayout.get(), paddedDim, paddedDim, FFTClass::Direction::X, true, srcNumChannels);
547+
driver->bindDescriptorSets(EPBP_COMPUTE, lastFFTPipeline->getLayout(), 0u, 1u, &lastFFTDescriptorSet.get(), nullptr);
548+
FFTClass::pushConstants(driver, lastFFTPipeline->getLayout(), paddedDim, paddedDim, FFTClass::Direction::X, true, srcNumChannels);
553549
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
554550

555551
if(false == savedToFile) {

include/nbl/ext/FFT/FFT.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class FFT : public core::TotalInterface
123123
return (paddedInputDimensions.width * paddedInputDimensions.height * paddedInputDimensions.depth * numChannels) * (sizeof(float) * 2);
124124
}
125125

126-
static core::smart_refctd_ptr<video::IGPUSpecializedShader> createShader(video::IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize);
126+
static core::smart_refctd_ptr<video::IGPUComputePipeline> getDefaultPipeline(video::IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize);
127127

128128
_NBL_STATIC_INLINE_CONSTEXPR uint32_t MAX_DESCRIPTOR_COUNT = 2u;
129129
static inline void updateDescriptorSet(
@@ -214,7 +214,7 @@ class FFT : public core::TotalInterface
214214

215215
static inline void pushConstants(
216216
video::IVideoDriver* driver,
217-
video::IGPUPipelineLayout * pipelineLayout,
217+
const video::IGPUPipelineLayout * pipelineLayout,
218218
asset::VkExtent3D const & inputDimension,
219219
asset::VkExtent3D const & paddedInputDimension,
220220
Direction direction,

src/nbl/ext/FFT/FFT.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,21 +73,24 @@ core::smart_refctd_ptr<IGPUDescriptorSetLayout> FFT::getDefaultDescriptorSetLayo
7373
bnd[0].type = EDT_COMBINED_IMAGE_SAMPLER;
7474
else
7575
bnd[0].type = EDT_STORAGE_BUFFER;
76+
// TODO: cache using the asset manager's caches
7677
return driver->createGPUDescriptorSetLayout(bnd,bnd+sizeof(bnd)/sizeof(IGPUDescriptorSetLayout::SBinding));
7778
}
7879

7980
//
8081
core::smart_refctd_ptr<IGPUPipelineLayout> FFT::getDefaultPipelineLayout(IVideoDriver* driver, FFT::DataType inputType)
8182
{
8283
auto pcRange = getDefaultPushConstantRanges();
84+
// TODO: cache using the asset manager's caches
8385
return driver->createGPUPipelineLayout(
8486
pcRange.begin(),pcRange.end(),
8587
getDefaultDescriptorSetLayout(driver,inputType),nullptr,nullptr,nullptr
8688
);
8789
}
8890

89-
core::smart_refctd_ptr<IGPUSpecializedShader> FFT::createShader(IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize)
91+
core::smart_refctd_ptr<video::IGPUComputePipeline> FFT::getDefaultPipeline(video::IVideoDriver* driver, DataType inputType, uint32_t maxDimensionSize)
9092
{
93+
// TODO: cache using the asset manager's caches
9194
uint32_t const maxPaddedDimensionSize = core::roundUpToPoT(maxDimensionSize);
9295

9396
const char* sourceFmt =
@@ -119,7 +122,7 @@ R"===(#version 430 core
119122
ISpecializedShader::SInfo{nullptr, nullptr, "main", ISpecializedShader::ESS_COMPUTE}
120123
);
121124

122-
return specializedShader;
125+
return driver->createGPUComputePipeline(nullptr, getDefaultPipelineLayout(driver,inputType), std::move(specializedShader));
123126
}
124127

125128
void FFT::defaultBarrier()

0 commit comments

Comments
 (0)