Skip to content

Commit d406e4d

Browse files
rearrange the code slightly
1 parent 7fd1e25 commit d406e4d

File tree

2 files changed

+40
-59
lines changed

2 files changed

+40
-59
lines changed

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -301,83 +301,68 @@ int main()
301301

302302
IAssetLoader::SAssetLoadParams lp;
303303
auto srcImageBundle = am->getAsset("../../media/colorexr.exr", lp);
304-
auto srcCpuImg = IAsset::castDown<ICPUImage>(srcImageBundle.getContents().begin()[0]);
305304
auto kerImageBundle = am->getAsset("../../media/kernels/physical_flare_512.exr", lp);
306-
auto kerCpuImg = IAsset::castDown<ICPUImage>(kerImageBundle.getContents().begin()[0]);
307-
308-
IGPUImage::SCreationParams srcImgInfo;
309-
IGPUImage::SCreationParams kerImgInfo;
310-
311-
smart_refctd_ptr<IGPUImage> outImg;
312-
smart_refctd_ptr<IGPUImageView> outImgView;
313305

314306
smart_refctd_ptr<IGPUImageView> srcImageView;
315-
IGPUImageView::SCreationParams srcImgViewInfo;
316307
{
317-
srcImgInfo = srcCpuImg->getCreationParameters();
318-
319-
auto srcGpuImages = driver->getGPUObjectsFromAssets(&srcCpuImg.get(),&srcCpuImg.get()+1);
320-
auto srcGpuImage = srcGpuImages->operator[](0u);
308+
auto srcGpuImages = driver->getGPUObjectsFromAssets<ICPUImage>(srcImageBundle.getContents());
321309

310+
IGPUImageView::SCreationParams srcImgViewInfo;
322311
srcImgViewInfo.flags = static_cast<IGPUImageView::E_CREATE_FLAGS>(0u);
323-
srcImgViewInfo.image = std::move(srcGpuImage);
312+
srcImgViewInfo.image = srcGpuImages->operator[](0u);
324313
srcImgViewInfo.viewType = IGPUImageView::ET_2D;
325-
srcImgViewInfo.format = srcImgInfo.format;
314+
srcImgViewInfo.format = srcImgViewInfo.image->getCreationParameters().format;
326315
srcImgViewInfo.subresourceRange.aspectMask = static_cast<IImage::E_ASPECT_FLAGS>(0u);
327316
srcImgViewInfo.subresourceRange.baseMipLevel = 0;
328317
srcImgViewInfo.subresourceRange.levelCount = 1;
329318
srcImgViewInfo.subresourceRange.baseArrayLayer = 0;
330319
srcImgViewInfo.subresourceRange.layerCount = 1;
331-
srcImageView = driver->createGPUImageView(IGPUImageView::SCreationParams(srcImgViewInfo));
320+
srcImageView = driver->createGPUImageView(std::move(srcImgViewInfo));
332321
}
333322
smart_refctd_ptr<IGPUImageView> kerImageView;
334323
{
335-
kerImgInfo = kerCpuImg->getCreationParameters();
336-
337-
auto kerGpuImages = driver->getGPUObjectsFromAssets(&kerCpuImg.get(),&kerCpuImg.get()+1);
338-
auto kerGpuImage = kerGpuImages->operator[](0u);
324+
auto kerGpuImages = driver->getGPUObjectsFromAssets<ICPUImage>(kerImageBundle.getContents());
339325

340326
IGPUImageView::SCreationParams kerImgViewInfo;
341327
kerImgViewInfo.flags = static_cast<IGPUImageView::E_CREATE_FLAGS>(0u);
342-
kerImgViewInfo.image = std::move(kerGpuImage);
328+
kerImgViewInfo.image = kerGpuImages->operator[](0u);
343329
kerImgViewInfo.viewType = IGPUImageView::ET_2D;
344-
kerImgViewInfo.format = kerImgInfo.format;
330+
kerImgViewInfo.format = kerImgViewInfo.image->getCreationParameters().format;
345331
kerImgViewInfo.subresourceRange.aspectMask = static_cast<IImage::E_ASPECT_FLAGS>(0u);
346332
kerImgViewInfo.subresourceRange.baseMipLevel = 0;
347333
kerImgViewInfo.subresourceRange.levelCount = 1;
348334
kerImgViewInfo.subresourceRange.baseArrayLayer = 0;
349335
kerImgViewInfo.subresourceRange.layerCount = 1;
350-
kerImageView = driver->createGPUImageView(IGPUImageView::SCreationParams(kerImgViewInfo));
336+
kerImageView = driver->createGPUImageView(std::move(kerImgViewInfo));
351337
}
352338

353339
using FFTClass = ext::FFT::FFT;
354340

355-
E_FORMAT srcFormat = srcImgInfo.format;
356-
E_FORMAT kerFormat = kerImgInfo.format;
357-
VkExtent3D srcDim = srcImgInfo.extent;
358-
VkExtent3D kerDim = kerImgInfo.extent;
341+
const E_FORMAT srcFormat = srcImageView->getCreationParameters().format;
359342
uint32_t srcNumChannels = getFormatChannelCount(srcFormat);
360-
uint32_t kerNumChannels = getFormatChannelCount(kerFormat);
343+
uint32_t kerNumChannels = getFormatChannelCount(kerImageView->getCreationParameters().format);
361344
//! OVERRIDE (we dont need alpha)
362345
srcNumChannels = channelCountOverride;
363346
kerNumChannels = channelCountOverride;
364347
assert(srcNumChannels == kerNumChannels); // Just to make sure, because the other case is not handled in this example
365348

366-
VkExtent3D paddedDim = FFTClass::padDimensionToNextPOT(srcDim, kerDim);
367-
uint32_t maxPaddedDimensionSize = core::max(core::max(paddedDim.width, paddedDim.height), paddedDim.depth);
368-
369-
VkExtent3D outImageDim = srcDim;
349+
const auto srcDim = srcImageView->getCreationParameters().image->getCreationParameters().extent;
370350

371351
// Create Out Image
352+
smart_refctd_ptr<IGPUImage> outImg;
353+
smart_refctd_ptr<IGPUImageView> outImgView;
372354
{
373-
srcImgInfo.extent = outImageDim;
374-
outImg = driver->createDeviceLocalGPUImageOnDedMem(std::move(srcImgInfo));
355+
auto dstImgViewInfo = srcImageView->getCreationParameters();
375356

376-
srcImgViewInfo.image = outImg;
377-
srcImgViewInfo.format = srcImgInfo.format;
378-
outImgView = driver->createGPUImageView(IGPUImageView::SCreationParams(srcImgViewInfo));
379-
}
357+
auto dstImgInfo = dstImgViewInfo.image->getCreationParameters();
358+
outImg = driver->createDeviceLocalGPUImageOnDedMem(std::move(dstImgInfo));
380359

360+
dstImgViewInfo.image = outImg;
361+
outImgView = driver->createGPUImageView(IGPUImageView::SCreationParams(dstImgViewInfo));
362+
}
363+
// TODO: re-examine
364+
const VkExtent3D paddedDim = FFTClass::padDimensionToNextPOT(srcDim);
365+
uint32_t maxPaddedDimensionSize = core::max(core::max(paddedDim.width, paddedDim.height), paddedDim.depth);
381366
auto fftGPUSpecializedShader_SSBOInput = FFTClass::createShader(driver, FFTClass::DataType::SSBO, maxPaddedDimensionSize);
382367
auto fftGPUSpecializedShader_ImageInput = FFTClass::createShader(driver, FFTClass::DataType::TEXTURE2D, maxPaddedDimensionSize);
383368
auto fftGPUSpecializedShader_KernelNormalization = [&]() -> auto
@@ -460,6 +445,7 @@ int main()
460445
kernelNormalizedSpectrums[i] = createKernelSpectrum();
461446

462447
// Precompute Kernel FFT
448+
const auto kerDim = kerImageView->getCreationParameters().image->getCreationParameters().extent;
463449
{
464450
// Ker FFT X
465451
auto fftDescriptorSet_Ker_FFT_X = driver->createGPUDescriptorSet(core::smart_refctd_ptr<const IGPUDescriptorSetLayout>(fftPipelineLayout_ImageInput->getDescriptorSetLayout(0u)));
@@ -610,8 +596,8 @@ int main()
610596
auto& region = regions->front();
611597

612598
region.bufferOffset = 0u;
613-
region.bufferRowLength = srcCpuImg->getRegions().begin()[0].bufferRowLength;
614-
region.bufferImageHeight = srcDim.height;
599+
region.bufferRowLength = 0u;
600+
region.bufferImageHeight = 0u;
615601
//region.imageSubresource.aspectMask = wait for Vulkan;
616602
region.imageSubresource.mipLevel = 0u;
617603
region.imageSubresource.baseArrayLayer = 0u;

include/nbl/ext/FFT/FFT.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,20 @@ class FFT : public core::TotalInterface
9696
}
9797

9898

99-
static inline asset::VkExtent3D padDimensionToNextPOT(asset::VkExtent3D const & dimension, asset::VkExtent3D const & minimum_dimension = asset::VkExtent3D{ 0, 0, 0 }) {
100-
asset::VkExtent3D ret = {};
101-
asset::VkExtent3D extendedDim = dimension;
102-
103-
if(dimension.width < minimum_dimension.width) {
104-
extendedDim.width = minimum_dimension.width;
105-
}
106-
if(dimension.height < minimum_dimension.height) {
107-
extendedDim.height = minimum_dimension.height;
108-
}
109-
if(dimension.depth < minimum_dimension.depth) {
110-
extendedDim.depth = minimum_dimension.depth;
111-
}
112-
113-
ret.width = core::roundUpToPoT(extendedDim.width);
114-
ret.height = core::roundUpToPoT(extendedDim.height);
115-
ret.depth = core::roundUpToPoT(extendedDim.depth);
116-
117-
return ret;
99+
static inline asset::VkExtent3D padDimensionToNextPOT(asset::VkExtent3D dimension, asset::VkExtent3D const & minimum_dimension = asset::VkExtent3D{ 1, 1, 1 })
100+
{
101+
if(dimension.width < minimum_dimension.width)
102+
dimension.width = minimum_dimension.width;
103+
if(dimension.height < minimum_dimension.height)
104+
dimension.height = minimum_dimension.height;
105+
if(dimension.depth < minimum_dimension.depth)
106+
dimension.depth = minimum_dimension.depth;
107+
108+
dimension.width = core::roundUpToPoT(dimension.width);
109+
dimension.height = core::roundUpToPoT(dimension.height);
110+
dimension.depth = core::roundUpToPoT(dimension.depth);
111+
112+
return dimension;
118113
}
119114

120115
//

0 commit comments

Comments
 (0)