@@ -390,8 +390,8 @@ int main()
390
390
auto fftPipeline_ImageInput = driver->createGPUComputePipeline (nullptr , core::smart_refctd_ptr (fftPipelineLayout_ImageInput), std::move (fftGPUSpecializedShader_ImageInput));
391
391
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline (nullptr , core::smart_refctd_ptr (fftPipelineLayout_KernelNormalization), std::move (fftGPUSpecializedShader_KernelNormalization));
392
392
393
- auto fftDispatchInfo_Horizontal = FFTClass::buildParameters (paddedDim, FFTClass::Direction::X, srcNumChannels );
394
- auto fftDispatchInfo_Vertical = FFTClass::buildParameters (paddedDim, FFTClass::Direction::Y, srcNumChannels );
393
+ auto fftDispatchInfo_Horizontal = FFTClass::buildParameters (paddedDim, FFTClass::Direction::X);
394
+ auto fftDispatchInfo_Vertical = FFTClass::buildParameters (paddedDim, FFTClass::Direction::Y);
395
395
396
396
auto convolveShader = createShader_Convolution (driver, am, maxPaddedDimensionSize);
397
397
auto convolvePipelineLayout = getPipelineLayout_Convolution (driver);
@@ -420,19 +420,19 @@ int main()
420
420
// Ker Image FFT X
421
421
driver->bindComputePipeline (fftPipeline_ImageInput.get ());
422
422
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_ImageInput.get (), 0u , 1u , &fftDescriptorSet_Ker_FFT_X.get (), nullptr );
423
- FFTClass::pushConstants (driver, fftPipelineLayout_ImageInput.get (), kerDim, paddedDim, FFTClass::Direction::X, false , FFTClass::PaddingType::FILL_WITH_ZERO);
423
+ FFTClass::pushConstants (driver, fftPipelineLayout_ImageInput.get (), kerDim, paddedDim, FFTClass::Direction::X, false , srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
424
424
FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
425
425
426
426
// Ker Image FFT Y
427
427
driver->bindComputePipeline (fftPipeline_SSBOInput.get ());
428
428
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get (), 0u , 1u , &fftDescriptorSet_Ker_FFT_Y.get (), nullptr );
429
- FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false );
429
+ FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false , srcNumChannels );
430
430
FFTClass::dispatchHelper (driver, fftDispatchInfo_Vertical);
431
431
432
432
// Ker Image FFT Y
433
433
driver->bindComputePipeline (fftPipeline_SSBOInput.get ());
434
434
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get (), 0u , 1u , &fftDescriptorSet_Ker_FFT_Y.get (), nullptr );
435
- FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false );
435
+ FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false , srcNumChannels );
436
436
FFTClass::dispatchHelper (driver, fftDispatchInfo_Vertical);
437
437
438
438
// Ker Normalization
@@ -476,19 +476,19 @@ int main()
476
476
// Src Image FFT X
477
477
driver->bindComputePipeline (fftPipeline_ImageInput.get ());
478
478
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_ImageInput.get (), 0u , 1u , &fftDescriptorSet_Src_FFT_X.get (), nullptr );
479
- FFTClass::pushConstants (driver, fftPipelineLayout_ImageInput.get (), srcDim, paddedDim, FFTClass::Direction::X, false , FFTClass::PaddingType::CLAMP_TO_EDGE);
479
+ FFTClass::pushConstants (driver, fftPipelineLayout_ImageInput.get (), srcDim, paddedDim, FFTClass::Direction::X, false , srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
480
480
FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
481
481
482
482
// Src Image FFT Y + Convolution + Convolved IFFT Y
483
483
driver->bindComputePipeline (convolvePipeline.get ());
484
484
driver->bindDescriptorSets (EPBP_COMPUTE, convolvePipelineLayout.get (), 0u , 1u , &convolveDescriptorSet.get (), nullptr );
485
- FFTClass::pushConstants (driver, convolvePipelineLayout.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false );
485
+ FFTClass::pushConstants (driver, convolvePipelineLayout.get (), paddedDim, paddedDim, FFTClass::Direction::Y, false , srcNumChannels );
486
486
FFTClass::dispatchHelper (driver, fftDispatchInfo_Vertical);
487
487
488
488
// Convolved IFFT X
489
489
driver->bindComputePipeline (fftPipeline_SSBOInput.get ());
490
490
driver->bindDescriptorSets (EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get (), 0u , 1u , &fftDescriptorSet_IFFT_X.get (), nullptr );
491
- FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::X, true );
491
+ FFTClass::pushConstants (driver, fftPipelineLayout_SSBOInput.get (), paddedDim, paddedDim, FFTClass::Direction::X, true , srcNumChannels );
492
492
FFTClass::dispatchHelper (driver, fftDispatchInfo_Horizontal);
493
493
494
494
// Remove Padding and Copy to GPU Image
0 commit comments