Skip to content

Commit 09a2b11

Browse files
author
horasal
committed
FIX quantization
1 parent fb28732 commit 09a2b11

File tree

4 files changed

+160
-88
lines changed

4 files changed

+160
-88
lines changed

ggml/src/ggml-cpu/arch/arm/quants.c

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -650,42 +650,6 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
650650
*s = sumf;
651651
}
652652

653-
void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
654-
assert(nrc == 1);
655-
UNUSED(nrc);
656-
UNUSED(bx);
657-
UNUSED(by);
658-
UNUSED(bs);
659-
assert(n % QK_MXFP6_E3M2 == 0);
660-
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
661-
662-
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
663-
const block_q8_0 * GGML_RESTRICT y = vy;
664-
665-
const int nb = n / QK_MXFP6_E3M2;
666-
667-
int ib = 0;
668-
float sumf = 0;
669-
670-
for (; ib < nb; ++ib) {
671-
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
672-
int sumi1 = 0;
673-
int sumi2 = 0;
674-
int sumi3 = 0;
675-
int sumi4 = 0;
676-
// Q8_0 (y) * MXFP6 (block_size = 32)
677-
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
678-
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j] & 0x3f];
679-
sumi2 += y[ib].qs[j + 1 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j] >> 6) | ((x[ib].qs[3 * j + 1] & 0x0F) << 2)];
680-
sumi3 += y[ib].qs[j + 2 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j + 1] >> 4) | ((x[ib].qs[3 * j + 2] & 0x03) << 4)];
681-
sumi4 += y[ib].qs[j + 3 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j + 2] >> 2];
682-
}
683-
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
684-
}
685-
*s = sumf;
686-
}
687-
688-
689653
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
690654
const int qk = QK8_0;
691655
const int nb = n / qk;

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -844,37 +844,129 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
844844

845845
void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
846846
assert(nrc == 1);
847-
UNUSED(nrc);
848-
UNUSED(bx);
849-
UNUSED(by);
850-
UNUSED(bs);
851-
assert(n % QK_MXFP6_E3M2 == 0);
852-
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
847+
UNUSED(nrc);
848+
UNUSED(bx);
849+
UNUSED(by);
850+
UNUSED(bs);
851+
assert(n % QK_MXFP6_E3M2 == 0);
852+
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
853+
assert(QK_MXFP6_E3M2 == 32);
854+
855+
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
856+
const block_q8_0 * GGML_RESTRICT y = vy;
857+
858+
const int nb = n / QK_MXFP6_E3M2;
859+
860+
int ib = 0;
861+
float sumf = 0;
862+
863+
#if defined __AVX2__
864+
__m256 accum_ps = _mm256_setzero_ps();
865+
866+
for (; ib + 1 < nb; ib += 2) {
867+
const block_mxfp6_e3m2 * x1 = &x[ib + 0];
868+
const block_q8_0 * y1 = &y[ib + 0];
869+
870+
const block_mxfp6_e3m2 * x2 = &x[ib + 1];
871+
const block_q8_0 * y2 = &y[ib + 1];
872+
873+
alignas(32) int16_t k_vals_1[32];
874+
{
875+
const uint8_t * q3 = x1->qs;
876+
for (int j = 0; j < 8; ++j) {
877+
const uint8_t b0 = q3[0];
878+
const uint8_t b1 = q3[1];
879+
const uint8_t b2 = q3[2];
880+
k_vals_1[4*j + 0] = kvalues_mxfp6_e3m2[b0 & 0x3F];
881+
k_vals_1[4*j + 1] = kvalues_mxfp6_e3m2[(b0 >> 6) | ((b1 & 0x0F) << 2)];
882+
k_vals_1[4*j + 2] = kvalues_mxfp6_e3m2[(b1 >> 4) | ((b2 & 0x03) << 4)];
883+
k_vals_1[4*j + 3] = kvalues_mxfp6_e3m2[b2 >> 2];
884+
q3 += 3;
885+
}
886+
}
887+
888+
alignas(32) int16_t k_vals_2[32];
889+
{
890+
const uint8_t * q3 = x2->qs;
891+
for (int j = 0; j < 8; ++j) {
892+
const uint8_t b0 = q3[0];
893+
const uint8_t b1 = q3[1];
894+
const uint8_t b2 = q3[2];
895+
k_vals_2[4*j + 0] = kvalues_mxfp6_e3m2[b0 & 0x3F];
896+
k_vals_2[4*j + 1] = kvalues_mxfp6_e3m2[(b0 >> 6) | ((b1 & 0x0F) << 2)];
897+
k_vals_2[4*j + 2] = kvalues_mxfp6_e3m2[(b1 >> 4) | ((b2 & 0x03) << 4)];
898+
k_vals_2[4*j + 3] = kvalues_mxfp6_e3m2[b2 >> 2];
899+
q3 += 3;
900+
}
901+
}
902+
903+
const __m256i k_1_lo = _mm256_load_si256((const __m256i *)(k_vals_1 + 0)); // k-vals 0-15
904+
const __m256i k_1_hi = _mm256_load_si256((const __m256i *)(k_vals_1 + 16)); // k-vals 16-31
905+
906+
const __m256i q8_1_all = _mm256_loadu_si256((const __m256i *)y1->qs);
907+
908+
const __m256i q8_1_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 0)); // q-vals 0-15
909+
const __m256i q8_1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_1_all, 1)); // q-vals 16-31
910+
911+
const __m256i p_1_lo = _mm256_madd_epi16(k_1_lo, q8_1_lo);
912+
const __m256i p_1_hi = _mm256_madd_epi16(k_1_hi, q8_1_hi);
913+
914+
const __m256i p_1_all = _mm256_add_epi32(p_1_lo, p_1_hi); // 8x s32
915+
916+
const __m256i k_2_lo = _mm256_load_si256((const __m256i *)(k_vals_2 + 0));
917+
const __m256i k_2_hi = _mm256_load_si256((const __m256i *)(k_vals_2 + 16));
918+
const __m256i q8_2_all = _mm256_loadu_si256((const __m256i *)y2->qs);
919+
const __m256i q8_2_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 0));
920+
const __m256i q8_2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8_2_all, 1));
921+
const __m256i p_2_lo = _mm256_madd_epi16(k_2_lo, q8_2_lo);
922+
const __m256i p_2_hi = _mm256_madd_epi16(k_2_hi, q8_2_hi);
923+
const __m256i p_2_all = _mm256_add_epi32(p_2_lo, p_2_hi); // 8x s32
924+
925+
const __m256 p_1_ps = _mm256_cvtepi32_ps(p_1_all);
926+
const __m256 p_2_ps = _mm256_cvtepi32_ps(p_2_all);
927+
928+
// (d = d_y * d_x)
929+
const float d1 = GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e);
930+
const float d2 = GGML_CPU_FP16_TO_FP32(y2->d) * GGML_E8M0_TO_FP32_HALF(x2->e);
931+
932+
const __m256 d_1_ps = _mm256_set1_ps(d1);
933+
const __m256 d_2_ps = _mm256_set1_ps(d2);
934+
935+
// Fused Multiply-Add (FMA): accum = (d * p) + accum
936+
accum_ps = _mm256_fmadd_ps(d_1_ps, p_1_ps, accum_ps);
937+
accum_ps = _mm256_fmadd_ps(d_2_ps, p_2_ps, accum_ps);
938+
}
853939

854-
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
855-
const block_q8_0 * GGML_RESTRICT y = vy;
940+
sumf = hsum_float_8(accum_ps);
941+
#endif
856942

857-
const int nb = n / QK_MXFP6_E3M2;
943+
for (; ib < nb; ++ib) {
944+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d) * GGML_E8M0_TO_FP32_HALF(x[ib].e);
858945

859-
int ib = 0;
860-
float sumf = 0;
946+
int sumi = 0;
861947

862-
for (; ib < nb; ++ib) {
863-
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
864-
int sumi1 = 0;
865-
int sumi2 = 0;
866-
int sumi3 = 0;
867-
int sumi4 = 0;
868-
// Q8_0 (y) * MXFP6 (block_size = 32)
869-
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
870-
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j] & 0x3f];
871-
sumi2 += y[ib].qs[j + 1 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j] >> 6) | ((x[ib].qs[3 * j + 1] & 0x0F) << 2)];
872-
sumi3 += y[ib].qs[j + 2 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j + 1] >> 4) | ((x[ib].qs[3 * j + 2] & 0x03) << 4)];
873-
sumi4 += y[ib].qs[j + 3 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j + 2] >> 2];
948+
for (int j = 0; j < QK_MXFP6_E3M2 / 4; ++j) {
949+
const uint8_t * q3 = x[ib].qs + 3 * j;
950+
const int8_t * q8 = y[ib].qs + 4 * j;
951+
952+
const uint8_t b0 = q3[0];
953+
const uint8_t b1 = q3[1];
954+
const uint8_t b2 = q3[2];
955+
956+
const uint8_t v0_idx = b0 & 0x3F;
957+
const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
958+
const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
959+
const uint8_t v3_idx = b2 >> 2;
960+
961+
sumi += q8[0] * kvalues_mxfp6_e3m2[v0_idx];
962+
sumi += q8[1] * kvalues_mxfp6_e3m2[v1_idx];
963+
sumi += q8[2] * kvalues_mxfp6_e3m2[v2_idx];
964+
sumi += q8[3] * kvalues_mxfp6_e3m2[v3_idx];
965+
}
966+
sumf += d * sumi;
874967
}
875-
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4);
876-
}
877-
*s = sumf;
968+
969+
*s = sumf;
878970
}
879971

880972
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {

ggml/src/ggml-cpu/quants.c

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
220220
*s = sumf;
221221
}
222222

223-
void ggml_vec_dot_mxfp6_e3m2_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
223+
void ggml_vec_dot_mxfp6_e3m2_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc)
224+
{
224225
assert(nrc == 1);
225226
UNUSED(nrc);
226227
UNUSED(bx);
@@ -240,18 +241,34 @@ void ggml_vec_dot_mxfp6_e3m2_q8_0_generic(int n, float * GGML_RESTRICT s, size_t
240241
for (; ib < nb; ++ib) {
    // per-block scale: Q8_0 scale times the MXFP6 block scale
    const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);

    // single integer accumulator for the block; it must carry the name used by
    // the additions below (the leftover `sumi1` declaration did not match them)
    int sumi = 0;

    // Q8_0 (y) * MXFP6 (block_size = 32)
    for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
        // 3 packed bytes hold the four 6-bit MXFP6 codes of elements 4*j .. 4*j + 3
        const uint8_t * q3 = x[ib].qs + 3 * j;
        // the four matching Q8_0 values
        const int8_t  * q8 = y[ib].qs + 4 * j;

        const uint8_t b0 = q3[0];
        const uint8_t b1 = q3[1];
        const uint8_t b2 = q3[2];

        // each index is declared exactly once (a repeated v0_idx declaration is a redefinition error)
        const uint8_t v0_idx = b0 & 0x3F;
        const uint8_t v1_idx = (b0 >> 6) | ((b1 & 0x0F) << 2);
        const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
        const uint8_t v3_idx = b2 >> 2;

        // y[4*j + k] * x[4*j + k] for k = 0..3
        sumi += q8[0] * kvalues_mxfp6_e3m2[v0_idx];
        sumi += q8[1] * kvalues_mxfp6_e3m2[v1_idx];
        sumi += q8[2] * kvalues_mxfp6_e3m2[v2_idx];
        sumi += q8[3] * kvalues_mxfp6_e3m2[v3_idx];
    }
    sumf += d * sumi;
}
255272
*s = sumf;
256273
}
257274

ggml/src/ggml-quants.c

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,11 @@ void quantize_row_mxfp6_e3m2_ref(const float * GGML_RESTRICT x, block_mxfp6_e3m2
308308
y[i].e = e;
309309

310310
// 4 * 6bit quant -> 3 bytes
311-
// Final Size: qk * 3 / 4
312311
for (int j = 0; j < qk / 4; ++j) {
313-
const uint8_t x0 = best_index_mxfp6_e3m2(x[i*qk + 0 + j], d);
314-
const uint8_t x1 = best_index_mxfp6_e3m2(x[i*qk + 1 * qk / 4 + j], d);
315-
const uint8_t x2 = best_index_mxfp6_e3m2(x[i*qk + 2 * qk / 4 + j], d);
316-
const uint8_t x3 = best_index_mxfp6_e3m2(x[i*qk + 3 * qk / 4 + j], d);
312+
const uint8_t x0 = best_index_mxfp6_e3m2(x[i*qk + 4*j + 0], d);
313+
const uint8_t x1 = best_index_mxfp6_e3m2(x[i*qk + 4*j + 1], d);
314+
const uint8_t x2 = best_index_mxfp6_e3m2(x[i*qk + 4*j + 2], d);
315+
const uint8_t x3 = best_index_mxfp6_e3m2(x[i*qk + 4*j + 3], d);
317316

318317
// 1100 0000
319318
y[i].qs[3*j] = x0 | ((x1 & 0x03) << 6);
@@ -500,15 +499,15 @@ void dequantize_row_mxfp6_e3m2(const block_mxfp6_e3m2 * GGML_RESTRICT x, float *
500499
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
501500

502501
for (int j = 0; j < qk / 4; ++j) {
503-
const int8_t x0 = kvalues_mxfp6_e3m2[x[i].qs[3 * j] & 0x3F];
504-
const int8_t x1 = kvalues_mxfp6_e3m2[(x[i].qs[3 * j] >> 6) | ((x[i].qs[3 * j + 1] & 0x0F) << 2)];
505-
const int8_t x2 = kvalues_mxfp6_e3m2[(x[i].qs[3 * j + 1] >> 4) | ((x[i].qs[3 * j + 2] & 0x03) << 4)];
506-
const int8_t x3 = kvalues_mxfp6_e3m2[x[i].qs[3 * j + 2] >> 2];
507-
508-
y[i*qk + j + 0 ] = x0*d;
509-
y[i*qk + j + 1 * qk/4] = x1*d;
510-
y[i*qk + j + 2 * qk/4] = x2*d;
511-
y[i*qk + j + 3 * qk/4] = x3*d;
502+
const int16_t x0 = kvalues_mxfp6_e3m2[x[i].qs[3 * j] & 0x3F];
503+
const int16_t x1 = kvalues_mxfp6_e3m2[(x[i].qs[3 * j] >> 6) | ((x[i].qs[3 * j + 1] & 0x0F) << 2)];
504+
const int16_t x2 = kvalues_mxfp6_e3m2[(x[i].qs[3 * j + 1] >> 4) | ((x[i].qs[3 * j + 2] & 0x03) << 4)];
505+
const int16_t x3 = kvalues_mxfp6_e3m2[x[i].qs[3 * j + 2] >> 2];
506+
507+
y[i*qk + 4 * j + 0] = x0*d;
508+
y[i*qk + 4 * j + 1] = x1*d;
509+
y[i*qk + 4 * j + 2] = x2*d;
510+
y[i*qk + 4 * j + 3] = x3*d;
512511
}
513512
}
514513
}

0 commit comments

Comments
 (0)