1414
1515#define  TILE_SIZE_X ${TILE_SIZE_X}
1616#define  TILE_SIZE_Y ${TILE_SIZE_Y}
17- #define  LOCAL_WG_SIZE 64 
1817
1918#define  op(X, A, B) ${OPERATOR}
2019
@@ -39,53 +38,46 @@ layout(push_constant) uniform restrict Block {
3938
4039layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
4140
42- //  For performance improvement, reduce register usage by caching positions in shared memory.
43- //  Offset index by 1 every 16 points to avoid bank access conflict.
44- #define  offset_pos_index(index) (index +  ((index) >>  4 ))
45- shared  ivec3  pos_shared[offset_pos_index(LOCAL_WG_SIZE *  TILE_SIZE_X *  TILE_SIZE_Y)];
46- 
4741/* 
4842 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an 
4943 * output tile for pointwise convolution is more efficient because the kernel 
5044 * size is only 1x1, making it easier to re-use loaded texels from t_kernel. 
5145 */  
5246void  main() {
5347  const  ivec2  out_limits_scaled =  (out_limits.xy +  ivec2 (TILE_SIZE_X -  1 , TILE_SIZE_Y -  1 )) /  ivec2 (TILE_SIZE_X, TILE_SIZE_Y);
54-   const  uint  shared_mem_stride =  LOCAL_WG_SIZE;
5548
5649  const  uint  div_by_x =  gl_GlobalInvocationID.x /  out_limits_scaled.x;
5750  const  ivec3  gpos =  ivec3 (
5851    gl_GlobalInvocationID.x %  out_limits_scaled.x,
5952    div_by_x %  out_limits_scaled.y,
6053    div_by_x /  out_limits_scaled.y);
6154
55+   //  If the top left position is out of bounds, then this invocation will have
56+   //  no work to do.
57+   if  (gpos.z >=  out_limits.z) {
58+     return ;
59+   }
60+ 
6261  //  Output position for TILE_SIZE = 2
6362  //  +--------+--------+
6463  //  | pos[0] | pos[1] |
6564  //  +--------+--------+
6665  //  | pos[2] | pos[3] |
6766  //  +--------+--------+
68-   ivec2  pos[TILE_SIZE_X *  TILE_SIZE_Y];
67+   ivec3  pos[TILE_SIZE_X *  TILE_SIZE_Y];
6968  for  (int  y =  0 , i =  0 ; y <  TILE_SIZE_Y; ++ y) {
7069    for  (int  x =  0 ; x <  TILE_SIZE_X; ++ x) {
71-       pos[i] =  ivec2 (gpos.x *  TILE_SIZE_X +  x, gpos.y *  TILE_SIZE_Y +  y);
72-       pos_shared[offset_pos_index((shared_mem_stride *  i) +  gl_LocalInvocationIndex)] =  ivec3 (pos[i], gpos.z);
70+       pos[i] =  ivec3 (gpos.x *  TILE_SIZE_X +  x, gpos.y *  TILE_SIZE_Y +  y, gpos.z);
7371      i++ ;
7472    }
7573  }
7674
77-   //  If the top left position is out of bounds, then this invocation will have
78-   //  no work to do.
79-   if  (gpos.z >=  out_limits.z) {
80-     return ;
81-   }
82- 
8375  //  Compute the index of the input texture that needs to be loaded for each
8476  //  output position. Note that negative indices can be produced indicating that
8577  //  the top-left element is in a region added by padding.
8678  ivec2  ipos[TILE_SIZE_X *  TILE_SIZE_Y];
8779  for  (int  i =  0 ; i <  TILE_SIZE_X *  TILE_SIZE_Y; ++ i) {
88-     ipos[i] =  pos[i] *  stride -  padding;
80+     ipos[i] =  pos[i].xy  *  stride -  padding;
8981  }
9082
9183  //  Final output array where each element is a tensor value.
@@ -171,10 +163,8 @@ void main() {
171163  }
172164
173165  for  (int  i =  0 ; i <  TILE_SIZE_X *  TILE_SIZE_Y; ++ i) {
174-     const  uint  index =  (shared_mem_stride *  i) +  gl_LocalInvocationIndex;
175-     const  ivec3  pos =  pos_shared[offset_pos_index(index)];
176-     if  (all (lessThan (pos, out_limits.xyz))) {
177-       imageStore(t_out, pos, op(vec4 (sum[i *  4 ], sum[i *  4  +  1 ], sum[i *  4  +  2 ], sum[i *  4  +  3 ]), out_min, out_max));
166+     if  (all (lessThan (pos[i], out_limits.xyz))) {
167+       imageStore(t_out, pos[i], op(vec4 (sum[i *  4 ], sum[i *  4  +  1 ], sum[i *  4  +  2 ], sum[i *  4  +  3 ]), out_min, out_max));
178168    }
179169  }
180170}
0 commit comments