@@ -66,8 +66,6 @@ void main() {
6666 return ;
6767 }
6868
69- T mat1[TILE_ROWS][4 ];
70- VEC4_T qmat2[4 ][TILE_TXCOLS];
7169 VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
7270
7371 VEC4_T scales[TILE_TXCOLS];
@@ -86,12 +84,35 @@ void main() {
8684 for (uint16_t pos = uint16_t(0 ), txpos = uint16_t(0 );
8785 pos < uint16_t(in_sizes.x);
8886 pos += uint16_t(4 ), txpos += uint16_t(1 )) {
87+
88+ T mat1[TILE_ROWS][4 ];
89+
90+ $if IN_STORAGE == "buffer ":
91+ uint in_row_txstride = div4(in_sizes.x);
92+
93+ // Preload input tensor
94+ for (int i = 0 ; i < TILE_ROWS; i++ ) {
95+ $if IN_STORAGE == "buffer ":
96+ VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
97+ mat1[i][0 ] = tmp.x;
98+ mat1[i][1 ] = tmp.y;
99+ mat1[i][2 ] = tmp.z;
100+ mat1[i][3 ] = tmp.w;
101+ $else :
102+ VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0 ), 0 ));
103+ mat1[i][0 ] = tmp.x;
104+ mat1[i][1 ] = tmp.y;
105+ mat1[i][2 ] = tmp.z;
106+ mat1[i][3 ] = tmp.w;
107+ }
108+
89109 $if WEIGHT_STORAGE == "buffer ":
90110 uint qmat2_bufi;
91111 uint weight_row_txstride = div4(weight_sizes.x);
92112
93113 // Preload weight tensor
94114 for (int r = 0 ; r < 4 ; r++ ) {
115+ VEC4_T qmat2[TILE_TXCOLS];
95116 $if QUANT_NBITS == 4 :
96117 $for c in range(0 , TILE_TXCOLS, 2 ):
97118 $if WEIGHT_STORAGE == "buffer ":
@@ -101,44 +122,21 @@ void main() {
101122 const uvec4 packed_weight_tex = texelFetch(
102123 t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0 );
103124
104- qmat2[r][ ${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4 ) - 8.0 );
105- qmat2[r][ ${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
125+ qmat2[${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4 ) - 8.0 );
126+ qmat2[${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
106127 $else :
107128 $for c in range(TILE_TXCOLS):
108129 $if WEIGHT_STORAGE == "buffer ":
109130 qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
110- qmat2[r][ ${c}] = t_weight[qmat2_bufi + ${c}];
131+ qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
111132 $else :
112- qmat2[r][ ${c}] = VEC4_T(
133+ qmat2[${c}] = VEC4_T(
113134 texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0 ));
114- }
115-
116- $if IN_STORAGE == "buffer ":
117- uint in_row_txstride = div4(in_sizes.x);
118135
119- // Preload input tensor
120- for (int i = 0 ; i < TILE_ROWS; i++ ) {
121- $if IN_STORAGE == "buffer ":
122- VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
123- mat1[i][0 ] = tmp.x;
124- mat1[i][1 ] = tmp.y;
125- mat1[i][2 ] = tmp.z;
126- mat1[i][3 ] = tmp.w;
127- $else :
128- VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0 ), 0 ));
129- mat1[i][0 ] = tmp.x;
130- mat1[i][1 ] = tmp.y;
131- mat1[i][2 ] = tmp.z;
132- mat1[i][3 ] = tmp.w;
133- }
134-
135- // Accumulate output
136- for (int r = 0 ; r < TILE_ROWS; ++ r) {
137- $for c in range(TILE_TXCOLS):
138- sums[r][${c}] += mat1[r][0 ] * qmat2[0 ][${c}] +
139- mat1[r][1 ] * qmat2[1 ][${c}] +
140- mat1[r][2 ] * qmat2[2 ][${c}] +
141- mat1[r][3 ] * qmat2[3 ][${c}];
136+ for (int tr = 0 ; tr < TILE_ROWS; ++ tr) {
137+ $for c in range(TILE_TXCOLS):
138+ sums[tr][${c}] += qmat2[${c}] * mat1[tr][r];
139+ }
142140 }
143141 }
144142
0 commit comments