@@ -40,12 +40,14 @@ layout(push_constant) uniform restrict Block {
4040
4141layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4242
43+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
44+
4345void main() {
44- const uint out_width_ntexels = divup4(out_sizes.x);
45- const uint out_col = ( gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
46- const uint out_row = ( gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
46+ const uint16_t out_width_ntexels = uint16_t( divup4(out_sizes.x) );
47+ const uint16_t out_col = uint16_t(( gl_GlobalInvocationID.x % out_width_ntexels) << 2 ) ;
48+ const uint16_t out_row = uint16_t(( gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS) ;
4749
48- if (out_row >= out_sizes.y) {
50+ if (out_row >= uint16_t( out_sizes.y) ) {
4951 return ;
5052 }
5153
@@ -54,29 +56,29 @@ void main() {
5456 VEC4_T c[TILE_ROWS];
5557
5658 $if SCALES_STORAGE == "buffer ":
57- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
59+ const VEC4_T scales = VEC4_T(t_scales[int ( out_col >> 2 ) ]);
5860 $else :
59- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
61+ const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2 (out_col >> 2 , 0 ), 0 ));
6062
6163 [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
6264 c[i] = VEC4_T(0.0 );
6365 }
6466
65- for (int pos = 0 ; pos < in_sizes.x; pos += 4 ) {
67+ for (uint16_t pos = uint16_t( 0 ) ; pos < uint16_t( in_sizes.x) ; pos += uint16_t( 4 ) ) {
6668 // Preload weight tensor
6769 [[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
6870 $if WEIGHT_STORAGE == "buffer ":
6971 b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2 ];
7072 $else :
71- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
73+ b[i] = VEC4_T(texelFetch(t_weight, u16vec2 (out_col >> 2 , pos + i), 0 ));
7274 }
7375
7476 // Preload input tensor
7577 [[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
7678 $if IN_STORAGE == "buffer ":
7779 a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2 ];
7880 $else :
79- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
81+ a[i] = VEC4_T(texelFetch(t_in, u16vec3 (pos >> 2 , out_row + i, 0 ), 0 ));
8082 }
8183
8284 // Accumulate output
0 commit comments