@@ -36,13 +36,18 @@ layout(push_constant) uniform restrict Block {
36
36
ivec4 weight_sizes;
37
37
};
38
38
39
+ #include "indexing_utils.h"
40
+
39
41
layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
40
42
43
+ #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
44
+
41
45
void main() {
42
- const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
43
- const uint out_col = gl_GlobalInvocationID.x << 2 ;
46
+ const uint16_t out_width_ntexels = uint16_t(divup4(out_sizes.x));
47
+ const uint16_t out_col = uint16_t((gl_GlobalInvocationID.x % out_width_ntexels) << 2 );
48
+ const uint16_t out_row = uint16_t((gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS);
44
49
45
- if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
50
+ if (out_row >= uint16_t( out_sizes.y) ) {
46
51
return ;
47
52
}
48
53
@@ -51,29 +56,29 @@ void main() {
51
56
VEC4_T c[TILE_ROWS];
52
57
53
58
$if SCALES_STORAGE == "buffer ":
54
- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
59
+ const VEC4_T scales = VEC4_T(t_scales[int ( out_col >> 2 ) ]);
55
60
$else :
56
- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
61
+ const VEC4_T scales = VEC4_T(texelFetch(t_scales, u16vec2 (out_col >> 2 , 0 ), 0 ));
57
62
58
63
[[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
59
64
c[i] = VEC4_T(0.0 );
60
65
}
61
66
62
- for (int pos = 0 ; pos < in_sizes.x; pos += 4 ) {
67
+ for (uint16_t pos = uint16_t( 0 ) ; pos < uint16_t( in_sizes.x) ; pos += uint16_t( 4 ) ) {
63
68
// Preload weight tensor
64
69
[[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
65
70
$if WEIGHT_STORAGE == "buffer ":
66
71
b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2 ];
67
72
$else :
68
- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
73
+ b[i] = VEC4_T(texelFetch(t_weight, u16vec2 (out_col >> 2 , pos + i), 0 ));
69
74
}
70
75
71
76
// Preload input tensor
72
77
[[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
73
78
$if IN_STORAGE == "buffer ":
74
79
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2 ];
75
80
$else :
76
- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
81
+ a[i] = VEC4_T(texelFetch(t_in, u16vec3 (pos >> 2 , out_row + i, 0 ), 0 ));
77
82
}
78
83
79
84
// Accumulate output
0 commit comments