Skip to content

Commit 6daa661

Browse files
committed
Fix for inaccuracies in the scalar version
1 parent 3b3d551 commit 6daa661

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

ggml/src/ggml-cpu/repack.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -532,12 +532,12 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
532532
int offset = ((k / 2) % 2) + j * 2;
533533
for (int i = 0; i < blocklen; ++i) {
534534
const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
535-
const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32);
535+
const int lbits_index = hbits_index + (k/4) * 256;
536536

537-
int8_t v0 = (int8_t)((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF) - 32;
538-
int8_t v1 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 32] & 0xF) - 32;
539-
int8_t v2 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF) - 32;
540-
int8_t v3 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF) - 32;
537+
int8_t v0 = (int8_t)(((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF)) - 32;
538+
int8_t v1 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 256] & 0xF)) - 32;
539+
int8_t v2 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF)) - 32;
540+
int8_t v3 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 256] >> 4) & 0xF)) - 32;
541541

542542
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
543543
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
@@ -556,6 +556,7 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
556556
}
557557
for (int j = 0; j < ncols_interleaved; j++) {
558558
s[x * ncols_interleaved + j] = sumf[j];
559+
559560
}
560561
}
561562
}
@@ -999,21 +1000,21 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
9991000
int offset = ((k / 2) % 2) + j * 2;
10001001
for (int i = 0; i < blocklen; ++i){
10011002
const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
1002-
const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32);
1003+
const int lbits_index = hbits_index + (k/4) * 256;
10031004

1004-
int8_t v0 = (int8_t)((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF) - 32;
1005-
int8_t v1 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 32] & 0xF) - 32;
1006-
int8_t v2 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF) - 32;
1007-
int8_t v3 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF) - 32;
1005+
int8_t v0 = (int8_t)(((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF)) - 32;
1006+
int8_t v1 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 256] & 0xF)) - 32;
1007+
int8_t v2 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF)) - 32;
1008+
int8_t v3 = (int8_t)((((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 256] >> 4) & 0xF)) - 32;
10081009

10091010
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
10101011
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
10111012
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
10121013
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
1013-
sumi1 = sumi1 * (scales_0[offset] & 0xF);
1014-
sumi2 = sumi2 * (scales_1[offset] & 0xF);
1015-
sumi3 = sumi3 * (scales_2[offset] & 0xF);
1016-
sumi4 = sumi4 * (scales_3[offset] & 0xF);
1014+
sumi1 = sumi1 * (scales_0[offset]);
1015+
sumi2 = sumi2 * (scales_1[offset]);
1016+
sumi3 = sumi3 * (scales_2[offset]);
1017+
sumi4 = sumi4 * (scales_3[offset]);
10171018
sumi += sumi1 + sumi2 + sumi3 + sumi4;
10181019
}
10191020
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];

0 commit comments

Comments
 (0)