Skip to content

Commit 37e43ac

Browse files
author
horasal
committed
implement E8M0_TO_FP32_ANY for better e
1 parent b175e6a commit 37e43ac

File tree

6 files changed

+87
-68
lines changed

6 files changed

+87
-68
lines changed

ggml/src/ggml-cpu/arch/x86/quants.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ void ggml_vec_dot_mxfp6_e3m2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
861861
const int16_t* kvalues = (const int16_t*)kvalues_mxfp6_e3m2;
862862

863863
for (int i = 0; i < nb; ++i) {
864-
const float d = GGML_CPU_FP16_TO_FP32(y[i].d) * GGML_E8M0_TO_FP32_HALF(x[i].e);
864+
const float d = GGML_CPU_FP16_TO_FP32(y[i].d) * GGML_E8M0_TO_FP32_ANY(x[i].e, 4);
865865

866866
const __m256i q8_v = _mm256_loadu_si256((const __m256i*)y[i].qs);
867867

@@ -932,7 +932,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, con
932932
const int16_t* kvalues = (const int16_t*)kvalues_mxfp6_e2m3;
933933

934934
for (int i = 0; i < nb; ++i) {
935-
const float d = GGML_CPU_FP16_TO_FP32(y[i].d) * GGML_E8M0_TO_FP32_HALF(x[i].e);
935+
const float d = GGML_CPU_FP16_TO_FP32(y[i].d) * GGML_E8M0_TO_FP32_ANY(x[i].e, 3);
936936

937937
const __m256i q8_v = _mm256_loadu_si256((const __m256i*)y[i].qs);
938938

ggml/src/ggml-cpu/quants.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ void ggml_vec_dot_mxfp6_e3m2_q8_0_generic(int n, float * GGML_RESTRICT s, size_t
243243
float sumf = 0;
244244

245245
for (; ib < nb; ++ib) {
246-
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
246+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_ANY(x[ib].e, 4);
247247
int sumi = 0;
248248
// Q8_0 (y) * MXFP6 (block_size = 32)
249249
for (int j = 0; j < QK_MXFP6_E3M2/4; ++j) {
@@ -294,7 +294,7 @@ void ggml_vec_dot_mxfp6_e2m3_q8_0_generic(int n, float * GGML_RESTRICT s, size_t
294294
float sumf = 0;
295295

296296
for (; ib < nb; ++ib) {
297-
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
297+
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_ANY(x[ib].e, 3);
298298
int sumi = 0;
299299
// Q8_0 (y) * MXFP6 (block_size = 32)
300300
for (int j = 0; j < QK_MXFP6_E2M3/4; ++j) {

ggml/src/ggml-cuda/convert.cu

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,10 @@ static __global__ void dequantize_block_mxfp6_e3m2(const void * __restrict__ vx,
513513
const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
514514
const uint8_t v3_idx = b2 >> 2;
515515

516-
y[y_offset + 0] = d * kvalues_mxfp6_e3m2[v0_idx]*0.0625f;
517-
y[y_offset + 1] = d * kvalues_mxfp6_e3m2[v1_idx]*0.0625f;
518-
y[y_offset + 2] = d * kvalues_mxfp6_e3m2[v2_idx]*0.0625f;
519-
y[y_offset + 3] = d * kvalues_mxfp6_e3m2[v3_idx]*0.0625f;
516+
y[y_offset + 0] = d * kvalues_mxfp6_e3m2[v0_idx]*MXFP6_E3M2_SCALER;
517+
y[y_offset + 1] = d * kvalues_mxfp6_e3m2[v1_idx]*MXFP6_E3M2_SCALER;
518+
y[y_offset + 2] = d * kvalues_mxfp6_e3m2[v2_idx]*MXFP6_E3M2_SCALER;
519+
y[y_offset + 3] = d * kvalues_mxfp6_e3m2[v3_idx]*MXFP6_E3M2_SCALER;
520520
}
521521
}
522522

@@ -552,11 +552,10 @@ static __global__ void dequantize_block_mxfp6_e2m3(const void * __restrict__ vx,
552552
const uint8_t v2_idx = (b1 >> 4) | ((b2 & 0x03) << 4);
553553
const uint8_t v3_idx = b2 >> 2;
554554

555-
// Is this correct?
556-
y[y_offset + 0] = d * kvalues_mxfp6_e2m3[v0_idx]*0.0625f;
557-
y[y_offset + 1] = d * kvalues_mxfp6_e2m3[v1_idx]*0.0625f;
558-
y[y_offset + 2] = d * kvalues_mxfp6_e2m3[v2_idx]*0.0625f;
559-
y[y_offset + 3] = d * kvalues_mxfp6_e2m3[v3_idx]*0.0625f;
555+
y[y_offset + 0] = d * kvalues_mxfp6_e2m3[v0_idx]*MXFP6_E2M3_SCALER;
556+
y[y_offset + 1] = d * kvalues_mxfp6_e2m3[v1_idx]*MXFP6_E2M3_SCALER;
557+
y[y_offset + 2] = d * kvalues_mxfp6_e2m3[v2_idx]*MXFP6_E2M3_SCALER;
558+
y[y_offset + 3] = d * kvalues_mxfp6_e2m3[v3_idx]*MXFP6_E2M3_SCALER;
560559
}
561560
}
562561

ggml/src/ggml-impl.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,30 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
488488
return result;
489489
}
490490

491+
// Converts an E8M0 exponent byte to float32, pre-divided by 2^e:
//     result = 2^(x - 127 - e)      (equal to ggml_e8m0_to_fp32(x) / 2^e)
// Useful with MXFP6 quantization since the kvalues tables are stored in a different scale.
//
//   x : E8M0 value (biased exponent, bias 127)
//   e : additional power-of-two divisor applied to the result
//
// Results smaller than the smallest float32 subnormal (2^-149) are flushed to 0.0f.
static inline float ggml_e8m0_to_fp32_any(uint8_t x, const uint8_t e) {
    // value of the float32 exponent field if the result is a normal number
    const int biased = (int) x - (int) e;

    uint32_t bits;
    if (biased >= 1) {
        // normal float32: exponent field = x - e, mantissa = 0
        bits = (uint32_t) biased << 23;
    } else if (biased >= -22) {
        // subnormal float32: 2^(biased - 127) = 2^(biased + 22) * 2^(-149),
        // i.e. a single mantissa bit at position (biased + 22)
        //   e=4, x=0: 0x00040000 = 2^(-131)
        //   e=4, x=1: 0x00080000 = 2^(-130)
        bits = (uint32_t) 1 << (biased + 22);
    } else {
        // below 2^(-149): not representable; the shift count above would be
        // negative (undefined behaviour), so return +0.0f explicitly
        bits = 0;
    }

    float result;
    memcpy(&result, &bits, sizeof(float));
    return result;
}
511+
491512
#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
492513
#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
514+
#define GGML_E8M0_TO_FP32_ANY(x,e) ggml_e8m0_to_fp32_any(x,e)
493515

494516
/**
495517
* Converts brain16 to float32.

ggml/src/ggml-quants.c

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,40 @@ static inline int best_index_mxfp4(float x, float e) {
270270
return best_index;
271271
}
272272

273+
// Reference quantizer: converts k float32 values from x into k / QK_MXFP4 MXFP4 blocks in y.
//
//   x : input row, k floats
//   y : output blocks; each block gets one shared exponent byte (e) and packed 4-bit indices (qs)
//   k : number of input values, must be a multiple of QK_MXFP4 (asserted)
//
// Within a block, byte qs[j] holds the index for element j in its low nibble and the
// index for element j + qk/2 in its high nibble.
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
    static const int qk = QK_MXFP4;

    assert(k % qk == 0);

    const int nb = k / qk;

    for (int i = 0; i < nb; i++) {
        float amax = 0.0f; // absolute max of the block

        for (int j = 0; j < qk; j++) {
            const float v = fabsf(x[i*qk + j]);

            if (amax < v) {
                amax = v;
            }
        }

        // shared exponent: floor(log2(amax)) - 2, biased by 127.
        // Clamp in float before narrowing: converting a negative or > 255 float to
        // uint8_t is undefined behaviour (reachable for amax < 2^-125 or amax == inf).
        // The upper bound is 254 because 0xFF is reserved for NaN in the OCP MX E8M0 encoding.
        uint8_t e = 0;
        if (amax > 0.0f) {
            const float ef = floorf(log2f(amax)) - 2 + 127;
            if (ef >= 254.0f) {
                e = 254;
            } else if (ef > 0.0f) {
                e = (uint8_t) ef;
            }
        }

        // decoded block scale used to pick the nearest table entry for every element
        const float d = GGML_E8M0_TO_FP32_HALF(e);

        y[i].e = e;

        for (int j = 0; j < qk/2; ++j) {
            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);

            y[i].qs[j]  = x0;
            y[i].qs[j] |= x1 << 4;
        }
    }
}
306+
273307
static inline int best_index_mxfp6_e3m2(float x, float e) {
274308
int best_index = 0;
275309
float best_err = fabsf(kvalues_mxfp6_e3m2[0]*e - x);
@@ -301,9 +335,9 @@ void quantize_row_mxfp6_e3m2_ref(const float * GGML_RESTRICT x, block_mxfp6_e3m2
301335
}
302336
}
303337

304-
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 4 + 127) : 0;
338+
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 3 + 127) : 0;
305339

306-
const float d = GGML_E8M0_TO_FP32_HALF(e);
340+
const float d = GGML_E8M0_TO_FP32_ANY(e, 4);
307341

308342
y[i].e = e;
309343

@@ -355,9 +389,9 @@ void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6_e2m3
355389
}
356390
}
357391

358-
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 3 + 127) : 0;
392+
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 1 + 127) : 0;
359393

360-
const float d = GGML_E8M0_TO_FP32_HALF(e);
394+
const float d = GGML_E8M0_TO_FP32_ANY(e, 3);
361395

362396
y[i].e = e;
363397

@@ -378,40 +412,6 @@ void quantize_row_mxfp6_e2m3_ref(const float * GGML_RESTRICT x, block_mxfp6_e2m3
378412
}
379413
}
380414

381-
void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
382-
static const int qk = QK_MXFP4;
383-
384-
assert(k % qk == 0);
385-
386-
const int nb = k / qk;
387-
388-
for (int i = 0; i < nb; i++) {
389-
float amax = 0.0f; // absolute max
390-
391-
for (int j = 0; j < qk; j++) {
392-
const float v = x[i*qk + j];
393-
394-
if (amax < fabsf(v)) {
395-
amax = fabsf(v);
396-
}
397-
}
398-
399-
const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
400-
401-
const float d = GGML_E8M0_TO_FP32_HALF(e);
402-
403-
y[i].e = e;
404-
405-
for (int j = 0; j < qk/2; ++j) {
406-
const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
407-
const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
408-
409-
y[i].qs[j] = x0;
410-
y[i].qs[j] |= x1 << 4;
411-
}
412-
}
413-
}
414-
415415
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
416416
static const int qk = QK4_0;
417417

@@ -550,7 +550,7 @@ void dequantize_row_mxfp6_e3m2(const block_mxfp6_e3m2 * GGML_RESTRICT x, float *
550550
const int nb = k / qk;
551551

552552
for (int i = 0; i < nb; i++) {
553-
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
553+
const float d = GGML_E8M0_TO_FP32_ANY(x[i].e, 4);
554554

555555
for (int j = 0; j < qk / 4; ++j) {
556556
const int16_t x0 = kvalues_mxfp6_e3m2[x[i].qs[3 * j] & 0x3F];
@@ -574,7 +574,7 @@ void dequantize_row_mxfp6_e2m3(const block_mxfp6_e2m3 * GGML_RESTRICT x, float *
574574
const int nb = k / qk;
575575

576576
for (int i = 0; i < nb; i++) {
577-
const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
577+
const float d = GGML_E8M0_TO_FP32_ANY(x[i].e, 3);
578578

579579
for (int j = 0; j < qk / 4; ++j) {
580580
const int16_t x0 = kvalues_mxfp6_e2m3[x[i].qs[3 * j] & 0x3F];

gguf-py/gguf/quants.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
import numpy as np
1212

1313

14-
# see ggml_e8m0_to_fp32_half in ggml-impl.h
15-
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
16-
bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
14+
def e8m0_to_fp32_any(x: np.ndarray, e: np.uint32) -> np.ndarray:
    """Convert E8M0 exponent bytes to float32, pre-divided by ``2**e``.

    Computes ``2.0 ** (x - 127 - e)`` element-wise by building the float32 bit
    pattern directly (see ggml_e8m0_to_fp32_any in ggml-impl.h).

    Args:
        x: array of E8M0 values (biased exponents, bias 127).
        e: additional power-of-two divisor applied to every element.

    Returns:
        float32 array with the same shape as ``x``. Results smaller than the
        smallest float32 subnormal (2**-149) are flushed to 0.
    """
    # Work in a signed dtype so that x - e cannot wrap around for x < e.
    biased = np.asarray(x).astype(np.int64) - int(e)

    # Subnormal results have a single mantissa bit at position biased + 22;
    # a negative position means the value is below 2**-149 and becomes 0.
    sub_pos = biased + 22
    subnormal = np.where(
        sub_pos >= 0,
        np.uint32(1) << np.clip(sub_pos, 0, 22).astype(np.uint32),
        np.uint32(0),
    )

    # Normal results: exponent field = x - e, mantissa = 0.
    normal = np.clip(biased, 0, 255).astype(np.uint32) << np.uint32(23)

    bits = np.where(biased < 1, subnormal, normal).astype(np.uint32)
    return bits.view(np.float32)
1817

19-
2018
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
2119
block_size, type_size = GGML_QUANT_SIZES[quant_type]
2220
if shape[-1] % block_size != 0:
@@ -666,7 +664,7 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
666664

667665
@classmethod
668666
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
669-
e8m0_to_fp32_half(x)
667+
return e8m0_to_fp32_any(x, 1)
670668

671669
@classmethod
672670
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
@@ -721,8 +719,8 @@ class MXFP6E3M2(__Quant, qtype=GGMLQuantizationType.MXFP6_E3M2):
721719
)
722720

723721
@classmethod
724-
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
725-
e8m0_to_fp32_half(x)
722+
def __e8m0_to_fp32_scaled(cls, x: np.ndarray) -> np.ndarray:
723+
return e8m0_to_fp32_any(x, 4)
726724

727725
@classmethod
728726
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
@@ -732,12 +730,12 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
732730

733731
with np.errstate(divide="ignore"):
734732
# convert log2(d_max) to e8m0
735-
e = np.where(d_max > 0, np.floor(np.log2(d_max)) - 4 + 127, 0).astype(
733+
e = np.where(d_max > 0, np.floor(np.log2(d_max)) - 9 + 127, 0).astype(
736734
np.uint8
737735
)
738736

739737
# d is float of above e8m0
740-
d = cls.__e8m0_to_fp32_half(e)
738+
d = cls.__e8m0_to_fp32_scaled(e)
741739

742740
kvalues = np.array(cls.kvalues, dtype=np.int16).reshape((1, 1, 64))
743741

@@ -774,7 +772,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
774772

775773
e, qs = np.hsplit(blocks, [1])
776774

777-
d = cls.__e8m0_to_fp32_half(e).astype(np.float32)
775+
d = cls.__e8m0_to_fp32_scaled(e).astype(np.float32)
778776

779777
qs_groups = qs.reshape((n_blocks, -1, 3))
780778
b0 = qs_groups[..., 0]
@@ -808,8 +806,8 @@ class MXFP6E2M3(__Quant, qtype=GGMLQuantizationType.MXFP6_E2M3):
808806
)
809807

810808
@classmethod
811-
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
812-
e8m0_to_fp32_half(x)
809+
def __e8m0_to_fp32_scaled(cls, x: np.ndarray) -> np.ndarray:
810+
return e8m0_to_fp32_any(x, 3)
813811

814812
@classmethod
815813
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
@@ -819,12 +817,12 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
819817

820818
with np.errstate(divide="ignore"):
821819
# convert log2(d_max) to e8m0
822-
e = np.where(d_max > 0, np.floor(np.log2(d_max)) - 3 + 127, 0).astype(
820+
e = np.where(d_max > 0, np.floor(np.log2(d_max)) - 6 + 127, 0).astype(
823821
np.uint8
824822
)
825823

826824
# d is float of above e8m0
827-
d = cls.__e8m0_to_fp32_half(e)
825+
d = cls.__e8m0_to_fp32_scaled(e)
828826

829827
kvalues = np.array(cls.kvalues, dtype=np.int16).reshape((1, 1, 64))
830828

@@ -861,7 +859,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
861859

862860
e, qs = np.hsplit(blocks, [1])
863861

864-
d = cls.__e8m0_to_fp32_half(e).astype(np.float32)
862+
d = cls.__e8m0_to_fp32_scaled(e).astype(np.float32)
865863

866864
qs_groups = qs.reshape((n_blocks, -1, 3))
867865
b0 = qs_groups[..., 0]

0 commit comments

Comments
 (0)