@@ -32,24 +32,26 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232
3333layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
3434
35+ #extension  GL_EXT_shader_explicit_arithmetic_types_int16 :  require
36+ 
3537/* 
3638 * Computes a 2D pointwise convolution of an NxN output tile. Calculating an 
3739 * output tile for pointwise convolution is more efficient because the kernel 
3840 * size is only 1x1, making it easier to re-use loaded texels from t_kernel. 
3941 */  
4042void  main() {
41-   const  ivec3  gpos =  ivec3 (gl_GlobalInvocationID);
43+   const  u16vec3  gpos =  u16vec3 (gl_GlobalInvocationID);
4244
4345  //  Output position for TILE_SIZE = 2
4446  //  +--------+--------+
4547  //  | pos[0] | pos[1] |
4648  //  +--------+--------+
4749  //  | pos[2] | pos[3] |
4850  //  +--------+--------+
49-   ivec3  pos[TILE_SIZE *  TILE_SIZE];
51+   u16vec3  pos[TILE_SIZE *  TILE_SIZE];
5052  for  (int  y =  0 , i =  0 ; y <  TILE_SIZE; ++ y) {
5153    for  (int  x =  0 ; x <  TILE_SIZE; ++ x) {
52-       pos[i] =  ivec3 (
54+       pos[i] =  u16vec3 (
5355          gpos.x *  TILE_SIZE +  x, gpos.y *  TILE_SIZE +  y, gpos.z);
5456      i++ ;
5557    }
@@ -64,13 +66,13 @@ void main() {
6466  //  Compute the index of the input texture that needs to be loaded for each
6567  //  output position. Subtracting padding can underflow (negative if signed,
6668  //  wrapping if unsigned) when the top-left element is in the padded region.
67-   ivec2  ipos[TILE_SIZE *  TILE_SIZE];
69+   u16vec2  ipos[TILE_SIZE *  TILE_SIZE];
6870  for  (int  i =  0 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
69-     ipos[i] =  pos[i].xy *  stride -  padding;
71+     ipos[i] =  pos[i].xy *  u16vec2( stride)  -  u16vec2( padding) ;
7072  }
7173
7274  vec4  sum[TILE_SIZE *  TILE_SIZE];
73-   sum[0 ] =  texelFetch(t_bias, ivec2 (gpos.z, 0 ), 0 );
75+   sum[0 ] =  texelFetch(t_bias, u16vec2 (gpos.z, 0 ), 0 );
7476  for  (int  i =  1 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
7577    sum[i] =  sum[0 ];
7678  }
@@ -81,13 +83,13 @@ void main() {
8183    //  channel (IC) dim is along the x-axis, and the batch (OC) dim is along
8284    //  the z-axis.
8385    vec4  in_tex[TILE_SIZE *  TILE_SIZE];
84-     const  vec4  ktex_0 =  texelFetch(t_kernel, ivec2 (z +  0 , gpos.z), 0 );
85-     const  vec4  ktex_1 =  texelFetch(t_kernel, ivec2 (z +  1 , gpos.z), 0 );
86-     const  vec4  ktex_2 =  texelFetch(t_kernel, ivec2 (z +  2 , gpos.z), 0 );
87-     const  vec4  ktex_3 =  texelFetch(t_kernel, ivec2 (z +  3 , gpos.z), 0 );
86+     const  vec4  ktex_0 =  texelFetch(t_kernel, u16vec2 (z +  0 , gpos.z), 0 );
87+     const  vec4  ktex_1 =  texelFetch(t_kernel, u16vec2 (z +  1 , gpos.z), 0 );
88+     const  vec4  ktex_2 =  texelFetch(t_kernel, u16vec2 (z +  2 , gpos.z), 0 );
89+     const  vec4  ktex_3 =  texelFetch(t_kernel, u16vec2 (z +  3 , gpos.z), 0 );
8890
8991    for  (int  i =  0 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
90-       in_tex[i] =  texelFetch(t_in, ivec3 (ipos[i], z4), 0 );
92+       in_tex[i] =  texelFetch(t_in, u16vec3 (ipos[i], z4), 0 );
9193    }
9294
9395    for  (int  i =  0 ; i <  TILE_SIZE *  TILE_SIZE; ++ i) {
0 commit comments