
Commit e54d41b

gguf-py : add Numpy MXFP4 de/quantization support (#15111)
* gguf-py : add MXFP4 de/quantization support
* ggml-quants : handle zero amax for MXFP4
Parent: 4850b52

File tree: 3 files changed, 68 insertions(+), 10 deletions(-)

* ggml/src/ggml-quants.c (+1, -1)
* gguf-py/gguf/quants.py (+53, -4)
* gguf-py/tests/test_quants.py (+14, -5)

ggml/src/ggml-quants.c

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
             }
         }

-        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;

         const float d = GGML_E8M0_TO_FP32_HALF(e);

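Note on the fix: log2f(0.0f) is -inf, so for an all-zero block the old expression cast an infinite value to uint8_t; the guard pins such blocks to e = 0 instead. A minimal NumPy sketch of the same shared-exponent rule (the helper name mxfp4_shared_exponent is made up for illustration):

    import numpy as np

    def mxfp4_shared_exponent(amax: float) -> int:
        # Biased E8M0 exponent of the block scale: floor(log2(amax)) - 2 + 127,
        # with all-zero blocks (amax == 0) pinned to 0, as in the guarded C line.
        return int(np.floor(np.log2(amax))) - 2 + 127 if amax > 0.0 else 0

    print(mxfp4_shared_exponent(6.0))  # 127: floor(log2(6)) = 2 and 2 - 2 + 127 = 127
    print(mxfp4_shared_exponent(0.0))  # 0: all-zero block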

gguf-py/gguf/quants.py

Lines changed: 53 additions & 4 deletions
@@ -228,8 +228,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         d = max / -8
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q4_0's reference rounding is cursed and depends on FMA
-        qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+        qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

         qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
         qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         d = max / -16
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q5_0's reference rounding is cursed and depends on FMA
-        q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+        q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)

         qs = q.reshape((n_blocks, 2, cls.block_size // 2))
         qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
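
Both hunks above drop the float64/FMA emulation of the C reference and round directly in float32: the scaled values land in [-8, 7] (Q4_0) or [-16, 15] (Q5_0), get shifted by 8.5 or 16.5, truncated, and clipped to the unsigned code range. A toy illustration of that rounding step alone (the sample values are made up, not from the patch):

    import numpy as np

    # Already-scaled values, i.e. what (blocks * id) produces for one toy Q4_0 block.
    scaled = np.array([-8.0, -0.3, 0.0, 0.49, 0.51, 7.0], dtype=np.float32)

    # Shift by 8.5, truncate toward zero, clamp to the 4-bit code range [0, 15].
    codes = np.trunc(scaled + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
    print(codes)  # [ 0  8  8  8  9 15]
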
@@ -655,6 +653,57 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         return (d * qs.astype(np.float32))


+class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
+    # e2m1 values (doubled)
+    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
+
+    @staticmethod
+    # see ggml_e8m0_to_fp32_half in ggml-impl.h
+    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
+        bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
+        return bits.view(np.float32)
+
+    @classmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        d = abs(blocks).max(axis=-1, keepdims=True)
+
+        with np.errstate(divide="ignore"):
+            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
+
+        errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
+        best = np.argmin(errs, axis=-1, keepdims=True)
+
+        qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
+        qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
+
+        qs = qs.reshape((n_blocks, cls.block_size // 2))
+
+        return np.concatenate([e, qs], axis=-1)
+
+    @classmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        e, qs = np.hsplit(blocks, [1])
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
+        qs = (qs & np.uint8(0x0F)).view(np.int8)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
+        qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
+
+        return (d * qs.astype(np.float32))
+
+
 class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
     ksigns: bytes = (
         b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"

gguf-py/tests/test_quants.py

Lines changed: 14 additions & 5 deletions
@@ -67,6 +67,7 @@ def __init__(self, libggml: Path):
             "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
             "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
             "tq1_0", "tq2_0",
+            "mxfp4",
             "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
             "iq4_nl", "iq4_xs",
         ):
@@ -140,14 +141,21 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
         return False


-def do_test(libggml_path: Path, quick: bool = False):
+def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
     ggml_quants = GGMLQuants(libggml_path)

     np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

     r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
-
-    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
+    # test zero blocks
+    r[0, 0, :] = 0
+    ## Maybe test infinities? (can make NANs, not really useful in practice)
+    # r[0, 1, 0] = np.inf
+    # r[0, 2, 0] = -np.inf
+    # r[0, 3, 0] = np.inf
+    # r[0, 3, 1] = -np.inf
+
+    for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
         has_dequantize = False
         has_quantize = False

@@ -228,11 +236,12 @@ def do_test(libggml_path: Path, quick: bool = False):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
-    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
+    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
     parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
+    parser.add_argument("--type", type=str, help="The quant type to test (all by default)")

     args = parser.parse_args()

     logging.basicConfig(level=logging.DEBUG)

-    do_test(args.libggml, args.quick)
+    do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)
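
The new user_type parameter lets a single format be exercised; `python gguf-py/tests/test_quants.py --type mxfp4` does this from the command line. A hypothetical Python driver with the same effect (it assumes the snippet runs alongside test_quants.py so the import resolves, and that libggml.so has been built; adjust the path as needed):

    import logging
    from pathlib import Path

    from gguf.constants import GGMLQuantizationType
    from test_quants import do_test  # assumes this runs next to gguf-py/tests/test_quants.py

    logging.basicConfig(level=logging.DEBUG)

    # Same effect as `--type mxfp4`: only the MXFP4 round-trip is compared
    # against the C reference. Point this at wherever libggml.so was built.
    do_test(Path("../../build/bin/libggml.so"), quick=False, user_type=GGMLQuantizationType.MXFP4)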
