diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 59f9f3880f8..a576a46d5b8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -14,6 +14,7 @@ #define TILE_SIZE_X ${TILE_SIZE_X} #define TILE_SIZE_Y ${TILE_SIZE_Y} +#define LOCAL_WG_SIZE 64 #define op(X, A, B) ${OPERATOR} @@ -42,10 +43,10 @@ layout(push_constant) uniform restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -// shared memory to hold calculated positions, this would reduce register usage thus improving performance. -// 64 is the number of threads in the local wg -$num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y -shared ivec2 pos_shared[${num_shared}]; +// For performance improvement, reduce register usage by caching positions in shared memory. +// Offset index by 1 every 16 points to avoid bank access conflict. +#define offset_pos_index(index) (index + ((index) >> 4)) +shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y)]; /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an @@ -54,7 +55,7 @@ shared ivec2 pos_shared[${num_shared}]; */ void main() { const ivec2 out_limits_scaled = (out_limits.xy + ivec2(TILE_SIZE_X - 1, TILE_SIZE_Y - 1)) / ivec2(TILE_SIZE_X, TILE_SIZE_Y); - const uint shared_mem_stride = 64; + const uint shared_mem_stride = LOCAL_WG_SIZE; const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x; const ivec3 gpos = ivec3( @@ -72,7 +73,7 @@ void main() { for (int y = 0, i = 0; y < TILE_SIZE_Y; ++y) { for (int x = 0; x < TILE_SIZE_X; ++x) { pos[i] = ivec2(gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y); - pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i]; + pos_shared[offset_pos_index((shared_mem_stride * i) + gl_LocalInvocationIndex)] = ivec3(pos[i], gpos.z); i++; } } @@ -152,9 +153,10 @@ void main() { } for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex]; - if (all(lessThan(ivec3(pos, gpos.z), out_limits.xyz))) { - imageStore(t_out, ivec3(pos, gpos.z), op(sum[i], out_min, out_max)); + const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex; + const ivec3 pos = pos_shared[offset_pos_index(index)]; + if (all(lessThan(pos, out_limits.xyz))) { + imageStore(t_out, pos, op(sum[i], out_min, out_max)); } } }