@@ -21,8 +21,6 @@ ${define_required_extensions(DTYPE)}
2121$if WEIGHT_STORAGE == "buffer ":
2222 ${define_required_extensions("int8")}
2323
24- #extension GL_EXT_control_flow_attributes : require
25-
2624layout (std430) buffer ;
2725
2826${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array= False)}
@@ -49,20 +47,18 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4947void main() {
5048 // txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
5149 $if TILE_TXCOLS > 1 :
52- const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS));
53- const uint16_t out_txcol = uint16_t(
54- (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
50+ const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS);
51+ const int out_txcol = (int (gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS;
5552 $else :
56- const uint16_t global_wg_x = uint16_t( divup4(out_sizes.x) );
57- const uint16_t out_txcol = uint16_t (gl_GlobalInvocationID.x % global_wg_x) ;
53+ const int global_wg_x = divup4(out_sizes.x);
54+ const int out_txcol = int (gl_GlobalInvocationID.x) % global_wg_x;
5855
59- const uint16_t out_row = uint16_t(
60- (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
56+ const int out_row = (int (gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS;
6157
6258 $if QUANT_NBITS == 4 :
63- const uint16_t weight_txcol = uint16_t( out_txcol / 2 ) ;
59+ const int weight_txcol = out_txcol / 2 ;
6460
65- if (out_row >= uint16_t (out_sizes.y)) {
61+ if (out_row >= int (out_sizes.y)) {
6662 return ;
6763 }
6864
@@ -73,9 +69,9 @@ void main() {
7369 sums[r][${c}] = VEC4_T(0.0 );
7470 }
7571
76- for (uint16_t pos = uint16_t( 0 ) , txpos = uint16_t( 0 ) ;
77- pos < uint16_t( in_sizes.x) ;
78- pos += uint16_t( 4 ) , txpos += uint16_t( 1 ) ) {
72+ for (int pos = 0 , txpos = 0 ;
73+ pos < in_sizes.x;
74+ pos += 4 , txpos += 1 ) {
7975
8076 T mat1[TILE_ROWS][4 ];
8177
@@ -91,7 +87,7 @@ void main() {
9187 mat1[i][2 ] = tmp.z;
9288 mat1[i][3 ] = tmp.w;
9389 $else :
94- VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3 (txpos, out_row + i, 0 ), 0 ));
90+ VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
9591 mat1[i][0 ] = tmp.x;
9692 mat1[i][1 ] = tmp.y;
9793 mat1[i][2 ] = tmp.z;
@@ -117,7 +113,7 @@ void main() {
117113 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
118114 $else :
119115 packed_weight_tex = texelFetch(
120- t_weight, u16vec2 (weight_txcol + ${c}, pos + r), 0 );
116+ t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
121117
122118 qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4 ) - 8.0 );
123119 qmat2[${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
@@ -128,7 +124,7 @@ void main() {
128124 qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
129125 $else :
130126 qmat2[${c}] = VEC4_T(
131- texelFetch(t_weight, u16vec2 (out_txcol + ${c}, pos + r), 0 ));
127+ texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
132128
133129 for (int tr = 0 ; tr < TILE_ROWS; ++ tr) {
134130 $for c in range(TILE_TXCOLS):
@@ -143,7 +139,7 @@ void main() {
143139 scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
144140 $else :
145141 scales[${c}] = VEC4_T(
146- texelFetch(t_scales, u16vec2 (out_txcol + ${c}, 0 ), 0 ));
142+ texelFetch(t_scales, ivec2 (out_txcol + ${c}, 0 ), 0 ));
147143
148144 // Store to output tensor
149145 $if OUT_STORAGE == "buffer ":
0 commit comments