@@ -36,13 +36,18 @@ layout(push_constant) uniform restrict Block {
3636 ivec4 weight_sizes;
3737};
3838
39+ #include "indexing_utils.h"
40+
// Workgroup size (x, y, z) is not hard-coded: each dimension is bound to a specialization constant (ids 0, 1, 2).
3941layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4042
43+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
44+
4145void main() {
42- const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
43- const uint out_col = gl_GlobalInvocationID.x << 2 ;
46+ const uint16_t out_width_ntexels = uint16_t(divup4(out_sizes.x));
47+ const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2 );
48+ const uint16_t out_row = uint16_t((gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS);
4449
45- if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
50+ if (out_row >= uint16_t( out_sizes.y) ) {
4651 return ;
4752 }
4853
@@ -51,29 +56,29 @@ void main() {
5156 VEC4_T c[TILE_ROWS];
5257
5358 $if SCALES_STORAGE == "buffer ":
54- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
59+ const VEC4_T scales = VEC4_T(t_scales[int ( out_col >> 2 ) ]);
5560 $else :
56- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
61+ const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2 (out_col >> 2 , 0 ), 0 ));
5762
5863 [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
5964 c[i] = VEC4_T(0.0 );
6065 }
6166
62- for (int pos = 0 ; pos < in_sizes.x; pos += 4 ) {
67+ for (uint16_t pos = uint16_t( 0 ) ; pos < uint16_t( in_sizes.x) ; pos += uint16_t( 4 ) ) {
6368 // Preload weight tensor
6469 [[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
6570 $if WEIGHT_STORAGE == "buffer ":
6671 b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2 ];
6772 $else :
68- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
73+ b[i] = VEC4_T(texelFetch(t_weight, u16vec2 (out_col >> 2 , pos + i), 0 ));
6974 }
7075
7176 // Preload input tensor
7277 [[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
7378 $if IN_STORAGE == "buffer ":
7479 a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2 ];
7580 $else :
76- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
81+ a[i] = VEC4_T(texelFetch(t_in, u16vec3 (pos >> 2 , out_row + i, 0 ), 0 ));
7782 }
7883
7984 // Accumulate output
0 commit comments