@@ -31,10 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
3131layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3232layout (constant_id = 3 ) const int packed_dim = C_DIM;
3333
34- #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
35-
3634void main() {
37- u16vec3 pos = u16vec3 (gl_GlobalInvocationID);
35+ ivec3 pos = ivec3 (gl_GlobalInvocationID);
3836
3937 if (any (greaterThanEqual (pos, out_limits.xyz))) {
4038 return ;
@@ -48,34 +46,40 @@ void main() {
4846 // index of packed dim in bchw format
4947 const int in_packed_dim_bchw_index = 3 - packed_dim;
5048
51- for (int j = 0 ; j < 4 ; ++ j, pos[packed_dim]++ ) {
52- ivec4 in_bchw_pos = ivec4 (0 ); // holds b,c,h,w
53- // determine input position based on output position and permute map
54- // out_ndims is in BCHW format
55- in_bchw_pos[out_ndims[0 ]] = (pos.z / channel_info.x);
56- in_bchw_pos[out_ndims[1 ]] = (pos.z % channel_info.x);
57- in_bchw_pos[out_ndims[2 ]] = pos.y;
58- in_bchw_pos[out_ndims[3 ]] = pos.x;
49+ // determine input position based on output position and permute map
50+ // out_ndims is in BCHW format
51+ ivec4 in_bchw_pos = ivec4 (0 ); // holds b,c,h,w
52+ in_bchw_pos[out_ndims[0 ]] = (pos.z / channel_info.x);
53+ in_bchw_pos[out_ndims[1 ]] = (pos.z % channel_info.x);
54+ in_bchw_pos[out_ndims[2 ]] = pos.y;
55+ in_bchw_pos[out_ndims[3 ]] = pos.x;
5956
57+ for (int j = 0 ; j < 4 ; ++ j) {
58+ // terminate the loop if trying to access input texture out of bounds
6059 if (any (greaterThanEqual (in_bchw_pos.wzyx, in_sizes.xyzw))) {
6160 break ;
6261 }
62+ ivec3 fetch_pos;
6363
64- // input tensor's packed dim pos (in xyz format) corresponding to output tensor's pos (which is also in xyz format)
65- const int in_packed_dim_pos = in_bchw_pos[in_packed_dim_bchw_index];
64+ fetch_pos.xy = in_bchw_pos.wz;
65+ // calculate input position in z axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
66+ fetch_pos.z = in_bchw_pos.y + in_bchw_pos.x * channel_info.y;
6667
67- // calculate input position in y axis using batch and channel index which is in_bchw_pos.x and in_bchw_pos.y respectively
68- in_bchw_pos.y = in_bchw_pos.y + in_bchw_pos.x * channel_info.y ;
68+ // input tensor's packed dim lane corresponding to output tensor's pos
69+ const int in_packed_dim_lane_index = fetch_pos[packed_dim] & 0x3 ;
6970
7071 // scale down input tensor's packed dim pos to perform fetch
71- in_bchw_pos[in_packed_dim_bchw_index ] >>= 2 ;
72+ fetch_pos[packed_dim ] >>= 2 ;
7273
7374 // fetch input texel
74- VEC4_T inval = VEC4_T(texelFetch(image_in, u16vec3(in_bchw_pos.wzy), 0 ));
75- outval[j] = inval[in_packed_dim_pos & 0x3];
75+ VEC4_T inval = VEC4_T(texelFetch(image_in, fetch_pos, 0 ));
76+ outval[j] = inval[in_packed_dim_lane_index];
77+
78+ // go to next position in the input, that is mapped to the packed dim in the output
79+ in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++ ;
7680 }
7781
78- pos[packed_dim] = uint16_t (gl_GlobalInvocationID[packed_dim]);
82+ pos[packed_dim] = int (gl_GlobalInvocationID[packed_dim]);
7983
8084 imageStore(image_out, pos, outval);
8185}
0 commit comments