Skip to content

Commit bb9d978

Browse files
author
horasal
committed
Experimental support for MXFP6_E3M2
1 parent 5d195f1 commit bb9d978

File tree

18 files changed

+1776
-1171
lines changed

18 files changed

+1776
-1171
lines changed

ggml/include/ggml.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ extern "C" {
417417
// GGML_TYPE_IQ4_NL_4_8 = 37,
418418
// GGML_TYPE_IQ4_NL_8_8 = 38,
419419
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
420-
GGML_TYPE_COUNT = 40,
420+
GGML_TYPE_MXFP6_E3M2 = 40,
421+
GGML_TYPE_COUNT = 41,
421422
};
422423

423424
// precision
@@ -453,6 +454,7 @@ extern "C" {
453454
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
454455
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
455456
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
457+
GGML_FTYPE_MOSTLY_MXFP6_E3M2 = 26, // except 1d tensors
456458
};
457459

458460
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ typedef struct {
194194
} block_mxfp4;
195195
static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
196196

197+
#define QK_MXFP6_E3M2 32 // number of values stored in one MXFP6_E3M2 block
198+
typedef struct {
199+
uint8_t e; // E8M0 byte, one per block; the vec_dot kernels turn it into the block's scale factor
200+
uint8_t qs[QK_MXFP6_E3M2 * 3 / 4]; // packed 6-bit codes: 4 codes per 3 bytes, i.e. 24 bytes for 32 values
201+
} block_mxfp6_e3m2;
202+
static_assert(sizeof(block_mxfp6_e3m2) == sizeof(uint8_t) + QK_MXFP6_E3M2 * 3 / 4, "wrong mxfp6_e3m2 block size/padding");
203+
204+
197205
#define QK5_0 32
198206
typedef struct {
199207
ggml_half d; // delta
@@ -1095,6 +1103,15 @@ GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
10951103
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
10961104
GGML_TABLE_END()
10971105

1106+
GGML_TABLE_BEGIN(int16_t, kvalues_mxfp6_e3m2, 64) // lookup table indexed by a 6-bit code; int16_t because magnitudes reach 448
1107+
0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, // codes 0..15: after the first eight entries the step doubles every four entries
1108+
32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224,
1109+
256, 320, 384, 448,
1110+
0, -1, -2, -3, -4, -5, -6, -7, -8, -10, -12, -14, -16, -20, -24, -28, // codes 32..63: same magnitudes as codes 0..31, negated (bit 5 selects the sign)
1111+
-32, -40, -48, -56, -64, -80, -96, -112, -128, -160, -192, -224,
1112+
-256, -320, -384, -448,
1113+
GGML_TABLE_END()
1114+
10981115
#define NGRID_IQ1S 2048
10991116
#define IQ1S_DELTA 0.125f
11001117
#define IQ1M_DELTA 0.125f

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
1515
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
1616
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
17+
#define ggml_vec_dot_mxfp6_e3m2_q8_0_generic ggml_vec_dot_mxfp6_e3m2_q8_0
1718
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
1819
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
1920
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -98,6 +99,7 @@
9899
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
99100
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
100101
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
102+
#define ggml_vec_dot_mxfp6_e3m2_q8_0_generic ggml_vec_dot_mxfp6_e3m2_q8_0
101103
// repack.cpp
102104
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
103105
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -131,6 +133,7 @@
131133
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
132134
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
133135
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
136+
#define ggml_vec_dot_mxfp6_e3m2_q8_0_generic ggml_vec_dot_mxfp6_e3m2_q8_0
134137
// repack.cpp
135138
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
136139
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -193,6 +196,7 @@
193196
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
194197
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
195198
#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
199+
#define ggml_vec_dot_mxfp6_e3m2_q8_0_generic ggml_vec_dot_mxfp6_e3m2_q8_0
196200
// repack.cpp
197201
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
198202
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8

ggml/src/ggml-cpu/arch/arm/quants.c

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,42 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
650650
*s = sumf;
651651
}
652652

653+
void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // Scalar dot product of n MXFP6_E3M2 values (vx) with n Q8_0 values (vy), block by block; the result is written to *s.
654+
assert(nrc == 1); // only nrc == 1 is handled by this kernel
655+
UNUSED(nrc);
656+
UNUSED(bx);
657+
UNUSED(by);
658+
UNUSED(bs);
659+
assert(n % QK_MXFP6_E3M2 == 0); // n must cover a whole number of blocks
660+
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
661+
662+
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
663+
const block_q8_0 * GGML_RESTRICT y = vy;
664+
665+
const int nb = n / QK_MXFP6_E3M2; // number of blocks to process
666+
667+
int ib = 0;
668+
float sumf = 0;
669+
670+
for (; ib < nb; ++ib) {
671+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); // combined per-block scale of y and x; NOTE(review): kvalues_mxfp6_e3m2 holds integer-scaled values, so the _HALF factor must match the exponent written by quantize_row_mxfp6_e3m2_ref - confirm
672+
int sumi1 = 0;
673+
int sumi2 = 0;
674+
int sumi3 = 0;
675+
int sumi4 = 0;
676+
// Q8_0 (y) * MXFP6 (block_size = 32)
677+
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) { // each iteration decodes 3 packed bytes into four 6-bit codes
678+
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j] & 0x3f]; // element j: low 6 bits of byte 3j
679+
sumi2 += y[ib].qs[j + 1 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j] >> 6) | ((x[ib].qs[3 * j + 1] & 0x0F) << 2)]; // element j+8: top 2 bits of byte 3j (low part) and low 4 bits of byte 3j+1 (high part)
680+
sumi3 += y[ib].qs[j + 2 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j + 1] >> 4) | ((x[ib].qs[3 * j + 2] & 0x03) << 4)]; // element j+16: top 4 bits of byte 3j+1 (low part) and low 2 bits of byte 3j+2 (high part)
681+
sumi4 += y[ib].qs[j + 3 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j + 2] >> 2]; // element j+24: top 6 bits of byte 3j+2
682+
}
683+
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); // the integer partial sums are scaled once per block
684+
}
685+
*s = sumf;
686+
}
687+
688+
653689
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
654690
const int qk = QK8_0;
655691
const int nb = n / qk;
@@ -3647,4 +3683,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
36473683
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
36483684
#endif
36493685
}
3650-

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,41 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
842842
*s = sumf;
843843
}
844844

845+
void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // Scalar dot product of n MXFP6_E3M2 values (vx) with n Q8_0 values (vy), block by block; the result is written to *s.
846+
assert(nrc == 1); // only nrc == 1 is handled by this kernel
847+
UNUSED(nrc);
848+
UNUSED(bx);
849+
UNUSED(by);
850+
UNUSED(bs);
851+
assert(n % QK_MXFP6_E3M2 == 0); // n must cover a whole number of blocks
852+
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
853+
854+
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
855+
const block_q8_0 * GGML_RESTRICT y = vy;
856+
857+
const int nb = n / QK_MXFP6_E3M2; // number of blocks to process
858+
859+
int ib = 0;
860+
float sumf = 0;
861+
862+
for (; ib < nb; ++ib) {
863+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); // combined per-block scale of y and x; NOTE(review): kvalues_mxfp6_e3m2 holds integer-scaled values, so the _HALF factor must match the exponent written by quantize_row_mxfp6_e3m2_ref - confirm
864+
int sumi1 = 0;
865+
int sumi2 = 0;
866+
int sumi3 = 0;
867+
int sumi4 = 0;
868+
// Q8_0 (y) * MXFP6 (block_size = 32)
869+
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) { // each iteration decodes 3 packed bytes into four 6-bit codes
870+
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j] & 0x3f]; // element j: low 6 bits of byte 3j
871+
sumi2 += y[ib].qs[j + 1 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j] >> 6) | ((x[ib].qs[3 * j + 1] & 0x0F) << 2)]; // element j+8: top 2 bits of byte 3j (low part) and low 4 bits of byte 3j+1 (high part)
872+
sumi3 += y[ib].qs[j + 2 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j + 1] >> 4) | ((x[ib].qs[3 * j + 2] & 0x03) << 4)]; // element j+16: top 4 bits of byte 3j+1 (low part) and low 2 bits of byte 3j+2 (high part)
873+
sumi4 += y[ib].qs[j + 3 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j + 2] >> 2]; // element j+24: top 6 bits of byte 3j+2
874+
}
875+
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); // the integer partial sums are scaled once per block
876+
}
877+
*s = sumf;
878+
}
879+
845880
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
846881
const int qk = QK8_0;
847882
const int nb = n / qk;
@@ -3817,4 +3852,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
38173852
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
38183853
#endif
38193854
}
3820-

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
259259
.vec_dot_type = GGML_TYPE_Q8_0,
260260
.nrows = 1,
261261
},
262+
[GGML_TYPE_MXFP6_E3M2] = {
263+
.from_float = quantize_row_mxfp6_e3m2,
264+
.vec_dot = ggml_vec_dot_mxfp6_e3m2_q8_0,
265+
.vec_dot_type = GGML_TYPE_Q8_0,
266+
.nrows = 1,
267+
},
262268
[GGML_TYPE_Q2_K] = {
263269
.from_float = quantize_row_q2_K,
264270
.vec_dot = ggml_vec_dot_q2_K_q8_K,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,7 @@ void ggml_compute_forward_add(
668668
case GGML_TYPE_Q5_1:
669669
case GGML_TYPE_Q8_0:
670670
case GGML_TYPE_MXFP4:
671+
case GGML_TYPE_MXFP6_E3M2:
671672
case GGML_TYPE_Q2_K:
672673
case GGML_TYPE_Q3_K:
673674
case GGML_TYPE_Q4_K:
@@ -1117,6 +1118,7 @@ void ggml_compute_forward_add1(
11171118
case GGML_TYPE_Q8_0:
11181119
case GGML_TYPE_Q8_1:
11191120
case GGML_TYPE_MXFP4:
1121+
case GGML_TYPE_MXFP6_E3M2:
11201122
case GGML_TYPE_Q2_K:
11211123
case GGML_TYPE_Q3_K:
11221124
case GGML_TYPE_Q4_K:
@@ -1244,7 +1246,8 @@ void ggml_compute_forward_acc(
12441246
case GGML_TYPE_Q5_1:
12451247
case GGML_TYPE_Q8_0:
12461248
case GGML_TYPE_Q8_1:
12471249
case GGML_TYPE_MXFP4:
1250+
case GGML_TYPE_MXFP6_E3M2:
12481251
case GGML_TYPE_Q2_K:
12491252
case GGML_TYPE_Q3_K:
12501253
case GGML_TYPE_Q4_K:
@@ -4140,7 +4143,8 @@ void ggml_compute_forward_out_prod(
41404143
case GGML_TYPE_Q5_0:
41414144
case GGML_TYPE_Q5_1:
41424145
case GGML_TYPE_Q8_0:
41434146
case GGML_TYPE_MXFP4:
4147+
case GGML_TYPE_MXFP6_E3M2:
41444148
case GGML_TYPE_Q2_K:
41454149
case GGML_TYPE_Q3_K:
41464150
case GGML_TYPE_Q4_K:
@@ -4415,7 +4419,8 @@ void ggml_compute_forward_set(
44154419
case GGML_TYPE_Q5_1:
44164420
case GGML_TYPE_Q8_0:
44174421
case GGML_TYPE_Q8_1:
44184422
case GGML_TYPE_MXFP4:
4423+
case GGML_TYPE_MXFP6_E3M2:
44194424
case GGML_TYPE_Q2_K:
44204425
case GGML_TYPE_Q3_K:
44214426
case GGML_TYPE_Q4_K:
@@ -4677,7 +4682,8 @@ void ggml_compute_forward_get_rows(
46774682
case GGML_TYPE_Q5_1:
46784683
case GGML_TYPE_Q8_0:
46794684
case GGML_TYPE_Q8_1:
46804685
case GGML_TYPE_MXFP4:
4686+
case GGML_TYPE_MXFP6_E3M2:
46814687
case GGML_TYPE_Q2_K:
46824688
case GGML_TYPE_Q3_K:
46834689
case GGML_TYPE_Q4_K:
@@ -5401,7 +5407,8 @@ void ggml_compute_forward_clamp(
54015407
case GGML_TYPE_Q5_1:
54025408
case GGML_TYPE_Q8_0:
54035409
case GGML_TYPE_Q8_1:
54045410
case GGML_TYPE_MXFP4:
5411+
case GGML_TYPE_MXFP6_E3M2:
54055412
case GGML_TYPE_Q2_K:
54065413
case GGML_TYPE_Q3_K:
54075414
case GGML_TYPE_Q4_K:

ggml/src/ggml-cpu/quants.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
5050
quantize_row_mxfp4_ref(x, y, k);
5151
}
5252

53+
void quantize_row_mxfp6_e3m2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { // CPU-backend entry point for MXFP6_E3M2 row quantization; has no implementation of its own
54+
quantize_row_mxfp6_e3m2_ref(x, y, k); // forwards all arguments unchanged to the reference quantizer (k is presumably the element count - confirm against the _ref definition)
55+
}
56+
5357
//
5458
// 2-6 bit quantization in super-blocks
5559
//
@@ -216,6 +220,41 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
216220
*s = sumf;
217221
}
218222

223+
void ggml_vec_dot_mxfp6_e3m2_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // Scalar dot product of n MXFP6_E3M2 values (vx) with n Q8_0 values (vy), block by block; the result is written to *s.
224+
assert(nrc == 1); // only nrc == 1 is handled by this kernel
225+
UNUSED(nrc);
226+
UNUSED(bx);
227+
UNUSED(by);
228+
UNUSED(bs);
229+
assert(n % QK_MXFP6_E3M2 == 0); // n must cover a whole number of blocks
230+
static_assert(QK_MXFP6_E3M2 == QK8_0, "QK_MXFP6_E3M2 and QK8_0 must be the same");
231+
232+
const block_mxfp6_e3m2 * GGML_RESTRICT x = vx;
233+
const block_q8_0 * GGML_RESTRICT y = vy;
234+
235+
const int nb = n / QK_MXFP6_E3M2; // number of blocks to process
236+
237+
int ib = 0;
238+
float sumf = 0;
239+
240+
for (; ib < nb; ++ib) {
241+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); // combined per-block scale of y and x; NOTE(review): kvalues_mxfp6_e3m2 holds integer-scaled values, so the _HALF factor must match the exponent written by quantize_row_mxfp6_e3m2_ref - confirm
242+
int sumi1 = 0;
243+
int sumi2 = 0;
244+
int sumi3 = 0;
245+
int sumi4 = 0;
246+
// Q8_0 (y) * MXFP6 (block_size = 32)
247+
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) { // each iteration decodes 3 packed bytes into four 6-bit codes
248+
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j] & 0x3f]; // element j: low 6 bits of byte 3j
249+
sumi2 += y[ib].qs[j + 1 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j] >> 6) | ((x[ib].qs[3 * j + 1] & 0x0F) << 2)]; // element j+8: top 2 bits of byte 3j (low part) and low 4 bits of byte 3j+1 (high part)
250+
sumi3 += y[ib].qs[j + 2 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[(x[ib].qs[3 * j + 1] >> 4) | ((x[ib].qs[3 * j + 2] & 0x03) << 4)]; // element j+16: top 4 bits of byte 3j+1 (low part) and low 2 bits of byte 3j+2 (high part)
251+
sumi4 += y[ib].qs[j + 3 * QK_MXFP6_E3M2/4] * kvalues_mxfp6_e3m2[ x[ib].qs[3 * j + 2] >> 2]; // element j+24: top 6 bits of byte 3j+2
252+
}
253+
sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); // the integer partial sums are scaled once per block
254+
}
255+
*s = sumf;
256+
}
257+
219258
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
220259
const int qk = QK8_0;
221260
const int nb = n / qk;

ggml/src/ggml-cpu/quants.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
2020
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
2121

2222
void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
23+
void quantize_row_mxfp6_e3m2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
2324

2425
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
2526
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -42,6 +43,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
4243
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
4344

4445
void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
46+
void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
4547

4648
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
4749
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

0 commit comments

Comments
 (0)