@@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
3131layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
3232layout (constant_id =  3 ) const  int  packed_dim =  C_DIM;
3333
34- #extension  GL_EXT_shader_explicit_arithmetic_types_int16 :  require
35- 
3634void  main() {
37-   u16vec3  pos =  u16vec3 (gl_GlobalInvocationID);
35+   ivec3  pos =  ivec3 (gl_GlobalInvocationID);
3836
3937  if  (any (greaterThanEqual (pos, out_limits.xyz))) {
4038    return ;
@@ -48,34 +46,40 @@ void main() {
4846  //  index of packed dim in bchw format
4947  const  int  in_packed_dim_bchw_index =  3  -  packed_dim;
5048
51-   for  (int  j =  0 ; j <  4 ; ++ j, pos[packed_dim]++ ) {
52-     ivec4  in_bchw_pos =  ivec4 (0 ); //  holds b,c,h,w
53-     //  determine input position based on output position and permute map
54-     //  out_ndims is in BCHW format
55-     in_bchw_pos[out_ndims[0 ]] =  (pos.z /  channel_info.x);
56-     in_bchw_pos[out_ndims[1 ]] =  (pos.z %  channel_info.x);
57-     in_bchw_pos[out_ndims[2 ]] =  pos.y;
58-     in_bchw_pos[out_ndims[3 ]] =  pos.x;
49+   //  determine input position based on output position and permute map
50+   //  out_ndims is in BCHW format
51+   ivec4  in_bchw_pos =  ivec4 (0 ); //  holds b,c,h,w
52+   in_bchw_pos[out_ndims[0 ]] =  (pos.z /  channel_info.x);
53+   in_bchw_pos[out_ndims[1 ]] =  (pos.z %  channel_info.x);
54+   in_bchw_pos[out_ndims[2 ]] =  pos.y;
55+   in_bchw_pos[out_ndims[3 ]] =  pos.x;
5956
57+   for  (int  j =  0 ; j <  4 ; ++ j) {
58+     //  terminate the loop if trying to access input texture out of bounds
6059    if  (any (greaterThanEqual (in_bchw_pos.wzyx, in_sizes.xyzw))) {
6160      break ;
6261    }
62+     ivec3  fetch_pos;
6363
64-     //  input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
65-     const  int  in_packed_dim_pos =  in_bchw_pos[in_packed_dim_bchw_index];
64+     fetch_pos.xy =  in_bchw_pos.wz;
65+     //  calculate the input position on the z axis from the batch and channel indices, which are in_bchw_pos.x and in_bchw_pos.y respectively
66+     fetch_pos.z =  in_bchw_pos.y +  in_bchw_pos.x *  channel_info.y;
6667
67-     //  calculate  input position in y axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively 
68-     in_bchw_pos.y  =  in_bchw_pos.y  +  in_bchw_pos.x  *  channel_info.y ;
68+     //  input tensor's packed dim lane corresponding to output tensor's pos 
69+     const   int  in_packed_dim_lane_index  =  fetch_pos[packed_dim]  &  0x3 ;
6970
7071    //  scale down input tensor's packed dim pos to perform fetch
71-     in_bchw_pos[in_packed_dim_bchw_index ] >>=  2 ;
72+     fetch_pos[packed_dim ] >>=  2 ;
7273
7374    //  fetch input texel
74-     VEC4_T inval =  VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0 ));
75-     outval[j] =  inval[in_packed_dim_pos &  0x3];
75+     VEC4_T inval =  VEC4_T(texelFetch(image_in, fetch_pos, 0 ));
76+     outval[j] =  inval[in_packed_dim_lane_index];
77+ 
78+     //  advance to the next input position along the input dim that is mapped to the packed dim in the output
79+     in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++ ;
7680  }
7781
78-   pos[packed_dim] =  uint16_t (gl_GlobalInvocationID[packed_dim]);
82+   pos[packed_dim] =  int (gl_GlobalInvocationID[packed_dim]);
7983
8084  imageStore(image_out, pos, outval);
8185}
0 commit comments