Skip to content

Commit 7dc7ddc

Browse files
committed
Compute each channel in the local memory of the physical threads
(the number of workgroups issued is now 1/4 of the previous version)
1 parent f8ee656 commit 7dc7ddc

File tree

6 files changed

+171
-151
lines changed

6 files changed

+171
-151
lines changed

examples_tests/49.ComputeFFT/fft_convolve_ifft.comp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,22 @@ nbl_glsl_complex nbl_glsl_ext_FFT_getPaddedData(in uvec3 coordinate, in uint cha
6969

7070
void convolve()
7171
{
72-
uint channel = nbl_glsl_ext_FFT_getChannel();
72+
uint numChannels = nbl_glsl_ext_FFT_Parameters_t_getNumChannels();
7373
uvec3 dimension = nbl_glsl_ext_FFT_Parameters_t_getDimensions();
7474
uint dataLength = nbl_glsl_ext_FFT_getDimLength(nbl_glsl_ext_FFT_Parameters_t_getPaddedDimensions());
7575

7676
uint thread_offset = gl_LocalInvocationIndex;
7777
uint num_virtual_threads = (dataLength-1u)/(_NBL_GLSL_WORKGROUP_SIZE_)+1u;
7878

79-
for(uint t = 0u; t < num_virtual_threads; t++)
80-
{
81-
uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
82-
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
83-
uint idx = channel * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
84-
nbl_glsl_complex temp = inoutData[idx];
85-
inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
79+
for(uint ch = 0u; ch < numChannels; ++ch) {
80+
for(uint t = 0u; t < num_virtual_threads; t++)
81+
{
82+
uint tid = thread_offset + t * _NBL_GLSL_EXT_FFT_WORKGROUP_SIZE_;
83+
uvec3 coords = nbl_glsl_ext_FFT_getCoordinates(tid);
84+
uint idx = ch * (dimension.x * dimension.y * dimension.z) + coords.z * (dimension.x * dimension.y) + coords.y * (dimension.x) + coords.x;
85+
nbl_glsl_complex temp = inoutData[idx];
86+
inoutData[idx] = nbl_glsl_complex_mul(temp, kerData[idx]);
87+
}
8688
}
8789
}
8890

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,8 @@ int main()
390390
auto fftPipeline_ImageInput = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(fftPipelineLayout_ImageInput), std::move(fftGPUSpecializedShader_ImageInput));
391391
auto fftPipeline_KernelNormalization = driver->createGPUComputePipeline(nullptr, core::smart_refctd_ptr(fftPipelineLayout_KernelNormalization), std::move(fftGPUSpecializedShader_KernelNormalization));
392392

393-
auto fftDispatchInfo_Horizontal = FFTClass::buildParameters(paddedDim, FFTClass::Direction::X, srcNumChannels);
394-
auto fftDispatchInfo_Vertical = FFTClass::buildParameters(paddedDim, FFTClass::Direction::Y, srcNumChannels);
393+
auto fftDispatchInfo_Horizontal = FFTClass::buildParameters(paddedDim, FFTClass::Direction::X);
394+
auto fftDispatchInfo_Vertical = FFTClass::buildParameters(paddedDim, FFTClass::Direction::Y);
395395

396396
auto convolveShader = createShader_Convolution(driver, am, maxPaddedDimensionSize);
397397
auto convolvePipelineLayout = getPipelineLayout_Convolution(driver);
@@ -420,19 +420,19 @@ int main()
420420
// Ker Image FFT X
421421
driver->bindComputePipeline(fftPipeline_ImageInput.get());
422422
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_ImageInput.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_X.get(), nullptr);
423-
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), kerDim, paddedDim, FFTClass::Direction::X, false, FFTClass::PaddingType::FILL_WITH_ZERO);
423+
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), kerDim, paddedDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::FILL_WITH_ZERO);
424424
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
425425

426426
// Ker Image FFT Y
427427
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
428428
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
429-
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false);
429+
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
430430
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
431431

432432
// Ker Image FFT Y
433433
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
434434
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get(), 0u, 1u, &fftDescriptorSet_Ker_FFT_Y.get(), nullptr);
435-
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false);
435+
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
436436
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
437437

438438
// Ker Normalization
@@ -476,19 +476,19 @@ int main()
476476
// Src Image FFT X
477477
driver->bindComputePipeline(fftPipeline_ImageInput.get());
478478
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_ImageInput.get(), 0u, 1u, &fftDescriptorSet_Src_FFT_X.get(), nullptr);
479-
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), srcDim, paddedDim, FFTClass::Direction::X, false, FFTClass::PaddingType::CLAMP_TO_EDGE);
479+
FFTClass::pushConstants(driver, fftPipelineLayout_ImageInput.get(), srcDim, paddedDim, FFTClass::Direction::X, false, srcNumChannels, FFTClass::PaddingType::CLAMP_TO_EDGE);
480480
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
481481

482482
// Src Image FFT Y + Convolution + Convolved IFFT Y
483483
driver->bindComputePipeline(convolvePipeline.get());
484484
driver->bindDescriptorSets(EPBP_COMPUTE, convolvePipelineLayout.get(), 0u, 1u, &convolveDescriptorSet.get(), nullptr);
485-
FFTClass::pushConstants(driver, convolvePipelineLayout.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false);
485+
FFTClass::pushConstants(driver, convolvePipelineLayout.get(), paddedDim, paddedDim, FFTClass::Direction::Y, false, srcNumChannels);
486486
FFTClass::dispatchHelper(driver, fftDispatchInfo_Vertical);
487487

488488
// Convolved IFFT X
489489
driver->bindComputePipeline(fftPipeline_SSBOInput.get());
490490
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_SSBOInput.get(), 0u, 1u, &fftDescriptorSet_IFFT_X.get(), nullptr);
491-
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::X, true);
491+
FFTClass::pushConstants(driver, fftPipelineLayout_SSBOInput.get(), paddedDim, paddedDim, FFTClass::Direction::X, true, srcNumChannels);
492492
FFTClass::dispatchHelper(driver, fftDispatchInfo_Horizontal);
493493

494494
// Remove Padding and Copy to GPU Image

0 commit comments

Comments
 (0)