Skip to content

Commit ea7214f

Browse files
use push constants for the blur kernel normalization.comp
1 parent 2df8fca commit ea7214f

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

examples_tests/49.ComputeFFT/main.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,11 @@ int main()
404404
FFTClass::updateDescriptorSet(driver, fftDescriptorSet_Ker_FFT_Y.get(), fftOutputBuffer_0, fftOutputBuffer_1);
405405

406406
// Normalization of FFT Y result
407+
struct NormalizationPushConstants
408+
{
409+
ext::FFT::uvec4 stride;
410+
ext::FFT::uvec4 bitreverse_shift;
411+
};
407412
auto fftPipelineLayout_KernelNormalization = [&]() -> auto
408413
{
409414
IGPUDescriptorSetLayout::SBinding bnd[] =
@@ -423,8 +428,12 @@ int main()
423428
nullptr
424429
},
425430
};
431+
SPushConstantRange pc_rng;
432+
pc_rng.offset = 0u;
433+
pc_rng.size = sizeof(NormalizationPushConstants);
434+
pc_rng.stageFlags = ISpecializedShader::ESS_COMPUTE;
426435
return driver->createGPUPipelineLayout(
427-
nullptr,nullptr,
436+
&pc_rng,&pc_rng+1u,
428437
driver->createGPUDescriptorSetLayout(bnd,bnd+2),nullptr,nullptr,nullptr
429438
);
430439
}();
@@ -494,6 +503,14 @@ int main()
494503
);
495504
driver->bindComputePipeline(fftPipeline_KernelNormalization.get());
496505
driver->bindDescriptorSets(EPBP_COMPUTE, fftPipelineLayout_KernelNormalization.get(), 0u, 1u, &fftDescriptorSet_KernelNormalization.get(), nullptr);
506+
{
507+
NormalizationPushConstants normalizationPC;
508+
normalizationPC.stride = {1u,paddedKerDim.width,paddedKerDim.width*paddedKerDim.height,paddedKerDim.width*paddedKerDim.height}; // TODO: take from the Y FFT pass
509+
normalizationPC.bitreverse_shift.x = 32-core::findMSB(paddedKerDim.width);
510+
normalizationPC.bitreverse_shift.y = 32-core::findMSB(paddedKerDim.height);
511+
normalizationPC.bitreverse_shift.z = 0;
512+
driver->pushConstants(fftPipelineLayout_KernelNormalization.get(),ICPUSpecializedShader::ESS_COMPUTE,0u,sizeof(normalizationPC),&normalizationPC);
513+
}
497514
{
498515
const uint32_t dispatchSizeX = (paddedKerDim.width-1u)/16u+1u;
499516
const uint32_t dispatchSizeY = (paddedKerDim.height-1u)/16u+1u;

examples_tests/49.ComputeFFT/normalization.comp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,20 @@ layout(set=0, binding=0) restrict readonly buffer InBuffer
1010

1111
layout(set=0, binding=1, rg16f) uniform image2D NormalizedKernel[3];
1212

13-
void main()
13+
layout(push_constant) uniform PushConstants
1414
{
15-
// TODO: push constants
16-
const uvec2 log2_sizes = findMSB(gl_WorkGroupSize*gl_NumWorkGroups).xy;
17-
const uvec3 strides = uvec3(1u,0x1u<<log2_sizes.x,0x1u<<(log2_sizes.x+log2_sizes.y));
15+
uvec4 strides;
16+
uvec4 bitreverse_shift;
17+
} pc;
1818

19-
const float power = length(in_data[0]);
20-
nbl_glsl_complex value = in_data[gl_GlobalInvocationID.x*strides.x+gl_GlobalInvocationID.y*strides.y+gl_GlobalInvocationID.z*strides.z]/power;
19+
void main()
20+
{
21+
nbl_glsl_complex value = in_data[nbl_glsl_dot(gl_GlobalInvocationID,pc.strides.xyz)];
2122

23+
const float power = length(in_data[0]);
2224

23-
uvec2 coord = bitfieldReverse(gl_GlobalInvocationID.xy)>>(uvec2(32u)-log2_sizes);
24-
const nbl_glsl_complex shift = nbl_glsl_expImaginary(-float(coord.x+coord.y)*nbl_glsl_PI); // TODO: does this shift go away later?
25-
value = nbl_glsl_complex_mul(value,shift);
25+
const uvec2 coord = bitfieldReverse(gl_GlobalInvocationID.xy)>>pc.bitreverse_shift.xy;
26+
const nbl_glsl_complex shift = nbl_glsl_expImaginary(-nbl_glsl_PI*float(coord.x+coord.y));
27+
value = nbl_glsl_complex_mul(value,shift)/power;
2628
imageStore(NormalizedKernel[gl_WorkGroupID.z],ivec2(coord),vec4(value,0.0,0.0));
2729
}

include/nbl/builtin/glsl/math/functions.glsl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77

88
#include <nbl/builtin/glsl/math/constants.glsl>
99

10+
int nbl_glsl_dot(in ivec2 a, in ivec2 b) {return a.x*b.x+a.y*b.y;}
11+
uint nbl_glsl_dot(in uvec2 a, in uvec2 b) {return a.x*b.x+a.y*b.y;}
12+
int nbl_glsl_dot(in ivec3 a, in ivec3 b) {return a.x*b.x+a.y*b.y+a.z*b.z;}
13+
uint nbl_glsl_dot(in uvec3 a, in uvec3 b) {return a.x*b.x+a.y*b.y+a.z*b.z;}
14+
int nbl_glsl_dot(in ivec4 a, in ivec4 b) {return a.x*b.x+a.y*b.y+a.z*b.z+a.w*b.w;}
15+
uint nbl_glsl_dot(in uvec4 a, in uvec4 b) {return a.x*b.x+a.y*b.y+a.z*b.z+a.w*b.w;}
16+
17+
//
1018
float nbl_glsl_erf(in float _x)
1119
{
1220
const float a1 = 0.254829592;

0 commit comments

Comments
 (0)