@@ -64,6 +64,12 @@ void main() {
6464
6565 T sums[TILE_ROWS * TILE_TXCOLS * 4 ];
6666
67+ $if QUANT_NBITS == 4 :
68+ // accumulate mat1 elements sum so -8 bias can be applied using it later.
69+ T mat1_accum[TILE_ROWS];
70+ $for r in range(TILE_ROWS):
71+ mat1_accum[${r}] = T(0.0 );
72+
6773 for (int r = 0 ; r < TILE_ROWS; ++ r) {
6874 $for c in range(TILE_TXCOLS):
6975 $for j in range(4 ):
@@ -86,6 +92,10 @@ void main() {
8692 VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
8793 $for j in range(4 ):
8894 mat1[i * 4 + ${j}] = mat1_vec4[${j}];
95+
96+ $if QUANT_NBITS == 4 :
97+ // Accumulate mat1 element sum, this will be multiplied with -8 later for converting 4 bit data to a signed number.
98+ mat1_accum[i] += mat1[i * 4 + 0 ] + mat1[i * 4 + 1 ] + mat1[i * 4 + 2 ] + mat1[i * 4 + 3 ];
8999 }
90100
91101 $if WEIGHT_STORAGE == "buffer ":
@@ -109,13 +119,13 @@ void main() {
109119 packed_weight_tex = texelFetch(
110120 t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
111121
112- qmat2_vec4 = ( VEC4_T(packed_weight_tex >> 4 ) - 8.0 );
122+ qmat2_vec4 = VEC4_T(packed_weight_tex >> 4 );
113123 qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = qmat2_vec4.x;
114124 qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = qmat2_vec4.y;
115125 qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = qmat2_vec4.z;
116126 qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = qmat2_vec4.w;
117127
118- qmat2_vec4 = ( VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
128+ qmat2_vec4 = VEC4_T(packed_weight_tex & 0x0F);
119129 qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = qmat2_vec4.x;
120130 qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = qmat2_vec4.y;
121131 qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = qmat2_vec4.z;
@@ -156,10 +166,16 @@ void main() {
156166 for (int r = 0 ; r < TILE_ROWS; ++ r) {
157167 VEC4_T scaled_sums;
158168 $for c in range(TILE_TXCOLS):
159- scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] * scales[${c}].x;
160- scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] * scales[${c}].y;
161- scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] * scales[${c}].z;
162- scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] * scales[${c}].w;
169+ $if QUANT_NBITS == 4 :
170+ scaled_sums.x = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].x;
171+ scaled_sums.y = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].y;
172+ scaled_sums.z = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].z;
173+ scaled_sums.w = (sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] + mat1_accum[r] * - 8.0 ) * scales[${c}].w;
174+ $else :
175+ scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] * scales[${c}].x;
176+ scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] * scales[${c}].y;
177+ scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] * scales[${c}].z;
178+ scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] * scales[${c}].w;
163179
164180 $if OUT_STORAGE == "buffer ":
165181 if (out_row + r < out_sizes.y) {
0 commit comments