1414
1515#define TILE_SIZE_X ${TILE_SIZE_X}
1616#define TILE_SIZE_Y ${TILE_SIZE_Y}
17+ #define LOCAL_WG_SIZE 64
1718
1819#define op(X, A, B) ${OPERATOR}
1920
@@ -42,10 +43,10 @@ layout(push_constant) uniform restrict Block {
4243
4344layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4445
45- // shared memory to hold calculated positions, this would reduce register usage thus improving performance .
46- // 64 is the number of threads in the local wg
47- $num_shared = 64 * TILE_SIZE_X * TILE_SIZE_Y
48- shared ivec2 pos_shared[${num_shared} ];
46+ // For performance improvement, reduce register usage by caching positions in shared memory .
47+ // Offset index by 1 every 16 points to avoid bank access conflict.
48+ #define offset_pos_index(index) (index + ((index) >> 4 ))
49+ shared ivec3 pos_shared[offset_pos_index(LOCAL_WG_SIZE * TILE_SIZE_X * TILE_SIZE_Y) ];
4950
5051/*
5152 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an
@@ -54,7 +55,7 @@ shared ivec2 pos_shared[${num_shared}];
5455 */
5556void main() {
5657 const ivec2 out_limits_scaled = (out_limits.xy + ivec2 (TILE_SIZE_X - 1 , TILE_SIZE_Y - 1 )) / ivec2 (TILE_SIZE_X, TILE_SIZE_Y);
57- const uint shared_mem_stride = 64 ;
58+ const uint shared_mem_stride = LOCAL_WG_SIZE ;
5859
5960 const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
6061 const ivec3 gpos = ivec3 (
@@ -72,7 +73,7 @@ void main() {
7273 for (int y = 0 , i = 0 ; y < TILE_SIZE_Y; ++ y) {
7374 for (int x = 0 ; x < TILE_SIZE_X; ++ x) {
7475 pos[i] = ivec2 (gpos.x * TILE_SIZE_X + x, gpos.y * TILE_SIZE_Y + y);
75- pos_shared[( shared_mem_stride * i) + gl_LocalInvocationIndex] = pos[i];
76+ pos_shared[offset_pos_index(( shared_mem_stride * i) + gl_LocalInvocationIndex) ] = ivec3 ( pos[i], gpos.z) ;
7677 i++ ;
7778 }
7879 }
@@ -152,9 +153,10 @@ void main() {
152153 }
153154
154155 for (int i = 0 ; i < TILE_SIZE_X * TILE_SIZE_Y; ++ i) {
155- const ivec2 pos = pos_shared[(shared_mem_stride * i) + gl_LocalInvocationIndex];
156- if (all (lessThan (ivec3 (pos, gpos.z), out_limits.xyz))) {
157- imageStore(t_out, ivec3 (pos, gpos.z), op(sum[i], out_min, out_max));
156+ const uint index = (shared_mem_stride * i) + gl_LocalInvocationIndex;
157+ const ivec3 pos = pos_shared[offset_pos_index(index)];
158+ if (all (lessThan (pos, out_limits.xyz))) {
159+ imageStore(t_out, pos, op(sum[i], out_min, out_max));
158160 }
159161 }
160162}
0 commit comments