@@ -34,13 +34,17 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434
3535#extension  GL_EXT_shader_explicit_arithmetic_types_int16 :  require
3636
37+ //  shared memory to hold calculated positions, this would reduce register usage thus improving performance.
38+ shared  u16vec2 pos_shared[gl_WorkGroupSize.x *  gl_WorkGroupSize.y *  gl_WorkGroupSize.z *  TILE_SIZE *  TILE_SIZE];
39+ 
3740/* 
3841 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an 
3942 * output tile for pointwise convolution is more efficient because the kernel 
4043 * size is only 1x1, making it easier to re-use loaded texels from t_kernel. 
4144 */  
4245void  main() {
4346  const  uint16_t out_limits_y_scaled =  uint16_t((out_limits.y +  TILE_SIZE -  1 ) /  TILE_SIZE);
47+   const  uint  shared_mem_stride =  gl_WorkGroupSize.x *  gl_WorkGroupSize.y *  gl_WorkGroupSize.z;
4448
4549  const  u16vec3 gpos =  u16vec3(
4650    gl_GlobalInvocationID.x /  (out_limits_y_scaled *  out_limits.z),
@@ -58,6 +62,7 @@ void main() {
5862    for  (int  x =  0 ; x <  TILE_SIZE; ++ x) {
5963      pos[i] =  u16vec2(
6064          gpos.x *  TILE_SIZE +  x, gpos.y *  TILE_SIZE +  y);
65+       pos_shared[(shared_mem_stride *  i) +  gl_LocalInvocationIndex] =  pos[i];
6166      i++ ;
6267    }
6368  }
@@ -73,7 +78,7 @@ void main() {
7378  //  the top-left element is in a region added by padding.
7479  u16vec2 ipos[TILE_SIZE *  TILE_SIZE];
7580  for  (int  i =  0 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
76-     ipos[i] =  pos[i].xy  *  u16vec2(stride) -  u16vec2(padding);
81+     ipos[i] =  pos[i] *  u16vec2(stride) -  u16vec2(padding);
7782  }
7883
7984  vec4  sum[TILE_SIZE *  TILE_SIZE];
@@ -138,8 +143,9 @@ void main() {
138143  }
139144
140145  for  (int  i =  0 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
141-     if  (all (lessThan (u16vec3(pos[i], gpos.z), out_limits))) {
142-       imageStore(t_out, u16vec3(pos[i], gpos.z), op(sum[i], out_min, out_max));
146+     const  u16vec2 pos =  pos_shared[(shared_mem_stride *  i) +  gl_LocalInvocationIndex];
147+     if  (all (lessThan (u16vec3(pos, gpos.z), out_limits))) {
148+       imageStore(t_out, u16vec3(pos, gpos.z), op(sum[i], out_min, out_max));
143149    }
144150  }
145151}
0 commit comments