Skip to content

Commit 2fadaa4

Browse files
account for margin that needs to be added for the kernel
1 parent ccc51cc commit 2fadaa4

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

examples_tests/49.ComputeFFT/last_fft.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ layout(set=0, binding=1, rgba16f) uniform image2D outImage;
88
void nbl_glsl_ext_FFT_setData(in uvec3 coordinate, in uint channel, in nbl_glsl_complex complex_value)
99
{
1010
// TODO PC
11-
const ivec2 padding = imageSize(outImage).x!=512u ? ivec2(384,152):ivec2(0);
11+
const ivec2 padding = imageSize(outImage).x!=512u ? ivec2(384,664):ivec2(0);
1212
const ivec2 coords = ivec2(coordinate.xy)-padding;
1313

1414
if (all(lessThanEqual(ivec2(0),coords))&&all(greaterThan(imageSize(outImage),coords)))

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,19 @@ int main()
342342
);
343343
}();
344344

345+
const auto kerDim = kerImageView->getCreationParameters().image->getCreationParameters().extent;
346+
const auto paddedSrcDim = [srcDim,kerDim]() -> auto
347+
{
348+
auto tmp = srcDim;
349+
tmp.width += kerDim.width-1u;
350+
tmp.height += kerDim.height-1u;
351+
tmp.depth += kerDim.depth-1u;
352+
return tmp;
353+
}();
345354
constexpr bool useHalfFloats = true;
346355
// Allocate Output Buffer
347-
auto fftOutputBuffer_0 = driver->createDeviceLocalGPUBufferOnDedMem(FFTClass::getOutputBufferSize(useHalfFloats,srcDim,srcNumChannels));
348-
auto fftOutputBuffer_1 = driver->createDeviceLocalGPUBufferOnDedMem(FFTClass::getOutputBufferSize(useHalfFloats,srcDim,srcNumChannels));
356+
auto fftOutputBuffer_0 = driver->createDeviceLocalGPUBufferOnDedMem(FFTClass::getOutputBufferSize(useHalfFloats,paddedSrcDim,srcNumChannels));
357+
auto fftOutputBuffer_1 = driver->createDeviceLocalGPUBufferOnDedMem(FFTClass::getOutputBufferSize(useHalfFloats,paddedSrcDim,srcNumChannels));
349358
core::smart_refctd_ptr<IGPUImageView> kernelNormalizedSpectrums[channelCountOverride];
350359

351360
auto updateDescriptorSet = [driver](video::IGPUDescriptorSet* set, core::smart_refctd_ptr<IGPUImageView> inputImageDescriptor, asset::ISampler::E_TEXTURE_CLAMP textureWrap, core::smart_refctd_ptr<IGPUBuffer> outputBufferDescriptor) -> void
@@ -400,7 +409,6 @@ int main()
400409

401410
// Precompute Kernel FFT
402411
{
403-
const auto kerDim = kerImageView->getCreationParameters().image->getCreationParameters().extent;
404412
const VkExtent3D paddedKerDim = FFTClass::padDimensions(kerDim);
405413

406414
// create kernel spectrums
@@ -553,8 +561,8 @@ int main()
553561
}
554562

555563
// pipelines
556-
auto fft_x = core::make_smart_refctd_ptr<FFTClass>(driver,srcDim.width,useHalfFloats);
557-
auto fft_y = core::make_smart_refctd_ptr<FFTClass>(driver,srcDim.height,useHalfFloats);
564+
auto fft_x = core::make_smart_refctd_ptr<FFTClass>(driver,paddedSrcDim.width,useHalfFloats);
565+
auto fft_y = core::make_smart_refctd_ptr<FFTClass>(driver,paddedSrcDim.height,useHalfFloats);
558566
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr,core::smart_refctd_ptr(imageFirstFFTPipelineLayout),createShader(driver,fft_x.get(), "../image_first_fft.comp"));
559567
auto convolvePipeline = driver->createGPUComputePipeline(nullptr, std::move(convolvePipelineLayout), createShader(driver,fft_y.get(), "../fft_convolve_ifft.comp"));
560568
auto lastFFTPipeline = driver->createGPUComputePipeline(nullptr, getPipelineLayout_LastFFT(driver), createShader(driver,fft_x.get(), "../last_fft.comp"));
@@ -584,13 +592,13 @@ int main()
584592
FFTClass::Parameters_t fftPushConstants[3];
585593
FFTClass::DispatchInfo_t fftDispatchInfo[3];
586594
const ISampler::E_TEXTURE_CLAMP fftPadding[2] = {ISampler::ETC_MIRROR,ISampler::ETC_MIRROR};
587-
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding);
595+
const auto passes = FFTClass::buildParameters(false,srcNumChannels,srcDim,fftPushConstants,fftDispatchInfo,fftPadding,paddedSrcDim);
588596
{
589597
fftPushConstants[1].input_dimensions.x = 2048u;
590598
fftPushConstants[1].output_strides = fftPushConstants[1].input_strides;
591599
fftPushConstants[2] = fftPushConstants[0];
592600
fftPushConstants[2].input_dimensions.x = 2048u;
593-
fftPushConstants[2].input_dimensions.y = 1024u;
601+
fftPushConstants[2].input_dimensions.y = 2048u;
594602
{
595603
fftPushConstants[2].input_dimensions.w ^= 0x80000000u;
596604
fftPushConstants[2].input_dimensions.w &= 0xfffffffdu;

include/nbl/ext/FFT/FFT.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,17 @@ class FFT final : public core::IReferenceCounted
4242
FFT(video::IDriver* driver, uint32_t maxDimensionSize, bool useHalfStorage = false);
4343

4444
// returns how many dispatches necessary for computing the FFT and fills the uniform data
45-
static inline uint32_t buildParameters(bool isInverse, uint32_t numChannels, const asset::VkExtent3D& inputDimensions, Parameters_t* outParams, DispatchInfo_t* outInfos, const asset::ISampler::E_TEXTURE_CLAMP* paddingType)
45+
static inline uint32_t buildParameters(
46+
bool isInverse, uint32_t numChannels, const asset::VkExtent3D& inputDimensions,
47+
Parameters_t* outParams, DispatchInfo_t* outInfos, const asset::ISampler::E_TEXTURE_CLAMP* paddingType,
48+
const asset::VkExtent3D& extraPaddedInputDimensions
49+
)
4650
{
4751
uint32_t passesRequired = 0u;
4852

4953
if (numChannels)
5054
{
51-
const auto paddedInputDimensions = padDimensions(inputDimensions);
55+
const auto paddedInputDimensions = padDimensions(extraPaddedInputDimensions);
5256
for (uint32_t i=0u; i<3u; i++)
5357
if ((&inputDimensions.width)[i]>1u)
5458
{
@@ -87,6 +91,13 @@ class FFT final : public core::IReferenceCounted
8791

8892
return passesRequired;
8993
}
94+
static inline uint32_t buildParameters(
95+
bool isInverse, uint32_t numChannels, const asset::VkExtent3D& inputDimensions,
96+
Parameters_t* outParams, DispatchInfo_t* outInfos, const asset::ISampler::E_TEXTURE_CLAMP* paddingType
97+
)
98+
{
99+
return buildParameters(isInverse,numChannels,inputDimensions,outParams,outInfos,paddingType,inputDimensions);
100+
}
90101

91102
static inline asset::VkExtent3D padDimensions(asset::VkExtent3D dimension)
92103
{

0 commit comments

Comments
 (0)