@@ -66,7 +66,7 @@ void main() {
6666 return ;
6767 }
6868
69- VEC4_T mat1[TILE_ROWS];
69+ T mat1[TILE_ROWS][ 4 ];
7070 VEC4_T qmat2[4 ][TILE_TXCOLS];
7171 VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
7272
@@ -78,7 +78,7 @@ void main() {
7878 scales[${c}] = VEC4_T(
7979 texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0 ), 0 ));
8080
81- [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
81+ for (int r = 0 ; r < TILE_ROWS; ++ r) {
8282 $for c in range(TILE_TXCOLS):
8383 sums[r][${c}] = VEC4_T(0.0 );
8484 }
@@ -91,7 +91,7 @@ void main() {
9191 uint weight_row_txstride = div4(weight_sizes.x);
9292
9393 // Preload weight tensor
94- [[unroll]] for (int r = 0 ; r < 4 ; r++ ) {
94+ for (int r = 0 ; r < 4 ; r++ ) {
9595 $if QUANT_NBITS == 4 :
9696 $for c in range(0 , TILE_TXCOLS, 2 ):
9797 $if WEIGHT_STORAGE == "buffer ":
@@ -117,21 +117,28 @@ void main() {
117117 uint in_row_txstride = div4(in_sizes.x);
118118
119119 // Preload input tensor
120- [[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
120+ for (int i = 0 ; i < TILE_ROWS; i++ ) {
121121 $if IN_STORAGE == "buffer ":
122- mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos];
122+ VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
123+ mat1[i][0 ] = tmp.x;
124+ mat1[i][1 ] = tmp.y;
125+ mat1[i][2 ] = tmp.z;
126+ mat1[i][3 ] = tmp.w;
123127 $else :
124- mat1[i] = VEC4_T(
125- texelFetch(t_in, u16vec3(txpos, out_row + i, 0 ), 0 ));
128+ VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0 ), 0 ));
129+ mat1[i][0 ] = tmp.x;
130+ mat1[i][1 ] = tmp.y;
131+ mat1[i][2 ] = tmp.z;
132+ mat1[i][3 ] = tmp.w;
126133 }
127134
128135 // Accumulate output
129- [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
136+ for (int r = 0 ; r < TILE_ROWS; ++ r) {
130137 $for c in range(TILE_TXCOLS):
131- sums[r][${c}] += mat1[r].x * qmat2[0 ][${c}] +
132- mat1[r].y * qmat2[1 ][${c}] +
133- mat1[r].z * qmat2[2 ][${c}] +
134- mat1[r].w * qmat2[3 ][${c}];
138+ sums[r][${c}] += mat1[r][ 0 ] * qmat2[0 ][${c}] +
139+ mat1[r][ 1 ] * qmat2[1 ][${c}] +
140+ mat1[r][ 2 ] * qmat2[2 ][${c}] +
141+ mat1[r][ 3 ] * qmat2[3 ][${c}];
135142 }
136143 }
137144
@@ -140,7 +147,7 @@ void main() {
140147 uint out_bufi;
141148 uint out_row_txstride = div4(out_sizes.x);
142149
143- [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
150+ for (int r = 0 ; r < TILE_ROWS; ++ r) {
144151 $for c in range(TILE_TXCOLS):
145152 $if OUT_STORAGE == "buffer ":
146153 if (out_row + r < out_sizes.y) {
0 commit comments