
Commit 29ca8f2

py quants: wip q4_k and q3_k quantize
1 parent 1f63e75 commit 29ca8f2

2 files changed: +201 -27 lines


gguf-py/gguf/quants.py

Lines changed: 192 additions & 27 deletions
@@ -431,6 +431,88 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
 
 
 class Q3_K(__Quant, qtype=GGMLQuantizationType.Q3_K):
+    @classmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        # 1. Sub-block scaling
+        sub_blocks = blocks.reshape((n_blocks, 16, 16))
+
+        # Find value with abs max for each sub-block to determine scale
+        abs_sub_blocks = np.abs(sub_blocks)
+        amax_indices = np.argmax(abs_sub_blocks, axis=-1, keepdims=True)
+        max_vals = np.take_along_axis(sub_blocks, amax_indices, axis=-1).squeeze(-1)
+
+        # For 3-bit quantization [-4, 3], the max absolute quant is 4
+        with np.errstate(divide="ignore", invalid="ignore"):
+            scales = np.where(max_vals != 0, max_vals / 4.0, 0)
+
+        # 2. Block-level scale (d)
+        abs_scales = np.abs(scales)
+        amax_indices_s = np.argmax(abs_scales, axis=-1, keepdims=True)
+        max_scale = np.take_along_axis(scales, amax_indices_s, axis=-1)
+
+        # 3. Quantize and pack scales
+        with np.errstate(divide="ignore", invalid="ignore"):
+            iscale = np.where(max_scale == 0, 0, -32.0 / max_scale)
+
+        # Quantize scales to 6-bit signed (-32 to 31), then shift to unsigned (0 to 63).
+        # Ensure the final type is uint8 to prevent casting errors later.
+        l = (np.clip(np_roundf(scales * iscale), -32, 31) + 32).astype(np.uint8)
+
+        # Pack the 16 6-bit values into 12 bytes
+        scales_packed = np.zeros((n_blocks, 12), dtype=np.uint8)
+        l_low = l & 0x0F
+        l_high = (l >> 4) & 0x03
+
+        scales_packed[:, 0:8] = l_low[:, 0:8]
+        scales_packed[:, 0:8] |= (l_low[:, 8:16] << 4)
+
+        l_high_reshaped = l_high.reshape(n_blocks, 4, 4).transpose(0, 2, 1)
+        packed_high_bits = l_high_reshaped[:, :, 0] | \
+                           (l_high_reshaped[:, :, 1] << 2) | \
+                           (l_high_reshaped[:, :, 2] << 4) | \
+                           (l_high_reshaped[:, :, 3] << 6)
+        scales_packed[:, 8:12] = packed_high_bits
+
+        # 4. Store block-level d
+        with np.errstate(divide="ignore", invalid="ignore"):
+            d_val = np.where(max_scale == 0, 0, max_scale / -32.0)
+        d = d_val.astype(np.float16).view(np.uint8)
+
+        # 5. Re-quantize data
+        # Dequantize scales to get effective scale for each sub-block
+        l_dequant = (l.astype(np.int8) - 32).astype(np.float32)
+        d_eff = (d_val * l_dequant).reshape(n_blocks, 16, 1)
+
+        with np.errstate(divide="ignore", invalid="ignore"):
+            l_data_float = np.where(d_eff == 0, 0, sub_blocks / d_eff)
+
+        # Quantize data to 3-bit signed [-4, 3], then shift to unsigned [0, 7].
+        # Ensure the final type is uint8.
+        l_data = (np.clip(np_roundf(l_data_float), -4, 3) + 4).astype(np.uint8)
+
+        # 6. Pack quants (qs and hmask)
+        l_data = l_data.reshape(n_blocks, 256)
+
+        # hmask stores the 3rd bit
+        hmask_values = (l_data > 3).reshape(n_blocks, 8, 32).transpose(0, 2, 1)
+        hmask = np.packbits(hmask_values, axis=-1, bitorder='little')
+        # Reshape hmask from (n_blocks, 32, 1) to (n_blocks, 32)
+        hmask = hmask.reshape(n_blocks, -1)
+
+        # qs stores the lower 2 bits
+        l_data_low = (l_data & 0x03).reshape(n_blocks, 2, 4, 32)
+
+        qs_parts = l_data_low[:, :, 0, :] | \
+                   (l_data_low[:, :, 1, :] << 2) | \
+                   (l_data_low[:, :, 2, :] << 4) | \
+                   (l_data_low[:, :, 3, :] << 6)
+        qs = qs_parts.reshape(n_blocks, 64)
+
+        # Final assembly in the order expected by dequantize
+        return np.concatenate([hmask, qs, scales_packed, d], axis=1)
+
     @classmethod
     def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         n_blocks = blocks.shape[0]
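
Step 3 is the densest part of the Q3_K path: 16 six-bit sub-block scales are packed into 12 bytes, with the low nibbles of scales 8-15 riding in the upper halves of bytes 0-7 and the two high bits of every scale distributed over bytes 8-11. A standalone sanity check of that layout (not part of this commit; the unpacking index arithmetic is inferred from the packing code above):

import numpy as np

rng = np.random.default_rng(0)
l = rng.integers(0, 64, size=(1, 16), dtype=np.uint8)  # 16 six-bit scales

# Pack exactly as quantize_blocks does above
packed = np.zeros((1, 12), dtype=np.uint8)
l_low, l_high = l & 0x0F, (l >> 4) & 0x03
packed[:, 0:8] = l_low[:, 0:8] | (l_low[:, 8:16] << 4)
h = l_high.reshape(1, 4, 4).transpose(0, 2, 1)
packed[:, 8:12] = h[:, :, 0] | (h[:, :, 1] << 2) | (h[:, :, 2] << 4) | (h[:, :, 3] << 6)

# Unpack: scale k takes its low nibble from byte k % 8 (upper half once k >= 8)
# and its two high bits from byte 8 + k % 4, shifted by 2 * (k // 4)
k = np.arange(16)
low4 = (packed[:, k % 8] >> (4 * (k // 8))) & 0x0F
high2 = (packed[:, 8 + k % 4] >> (2 * (k // 4))) & 0x03
assert np.array_equal(l, low4 | (high2 << 4))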
@@ -476,31 +558,111 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
 
 class Q4_K(__Quant, qtype=GGMLQuantizationType.Q4_K):
     K_SCALE_SIZE = 12
+    QK_K = QK_K  # Block size
+
+    @classmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        """
+        Quantizes a numpy array of floats into Q4_K format.
+        The last dimension of the input array must be a multiple of QK_K.
+
+        Args:
+            blocks: A numpy array of floats.
+
+        Returns:
+            A numpy array of uint8 with the quantized data, shaped (n_blocks, block_size_in_bytes).
+        """
+        x = blocks
+        if x.shape[-1] % cls.QK_K != 0:
+            raise ValueError(f"The last dimension of the input array must be a multiple of {cls.QK_K}, but got {x.shape[-1]}")
+
+        # Reshape the input into blocks of QK_K
+        n_blocks = x.size // cls.QK_K
+        blocks = x.reshape((n_blocks, cls.QK_K))
+
+        # Reshape to sub-blocks of 32 for processing
+        sub_blocks = blocks.reshape((n_blocks, 8, 32))
+
+        # --- 1. Find per-sub-block scales and mins ---
+        # The dequantization is x_approx = d * q - m, so q = (x_approx + m) / d.
+        # We define m = -min(x) to ensure m is positive.
+        mins_val = -np.min(sub_blocks, axis=-1, keepdims=True)
+        maxs_val = np.max(sub_blocks, axis=-1, keepdims=True)
+
+        # Calculate scales for 4-bit quantization (range 0-15)
+        scales_val = (maxs_val + mins_val) / 15.0
+        # Prevent division by zero if all values in a sub-block are the same
+        scales_val[scales_val < 1e-8] = 1.0
+
+        # --- 2. Find block-level d and dmin ---
+        max_scale = np.max(scales_val, axis=1, keepdims=True)
+        max_min = np.max(mins_val, axis=1, keepdims=True)
+
+        # --- 3. Quantize and pack the scales and mins ---
+        # Quantize the scales and mins to 6-bit integers (0-63)
+        with np.errstate(divide="ignore"):
+            inv_scale = np.where(max_scale == 0, 0, 63.0 / max_scale)
+            inv_min = np.where(max_min == 0, 0, 63.0 / max_min)
+
+        ls = np.clip(np_roundf(scales_val * inv_scale), 0, 63).astype(np.uint8)
+        lm = np.clip(np_roundf(mins_val * inv_min), 0, 63).astype(np.uint8)
+
+        # Pack the 6-bit scales (ls) and mins (lm) into a 12-byte array per block
+        scales_packed = np.zeros((n_blocks, cls.K_SCALE_SIZE), dtype=np.uint8)
+        # Lower 4 bits of ls[4..7] and lm[4..7]
+        scales_packed[:, 8:12] = (ls[:, 4:8, 0] & 0x0F) | ((lm[:, 4:8, 0] & 0x0F) << 4)
+        # Main part of ls[0..3] and lm[0..3]
+        scales_packed[:, 0:4] = ls[:, 0:4, 0] & 0x3F
+        scales_packed[:, 4:8] = lm[:, 0:4, 0] & 0x3F
+        # Higher 2 bits of ls and lm
+        scales_packed[:, 0:4] |= (ls[:, 4:8, 0] >> 4) << 6
+        scales_packed[:, 4:8] |= (lm[:, 4:8, 0] >> 4) << 6
+
+        # --- 4. Store block-level d and dmin as fp16 ---
+        with np.errstate(divide="ignore"):
+            d_val = np.where(max_scale == 0, 0, max_scale / 63.0)
+            dmin_val = np.where(max_min == 0, 0, max_min / 63.0)
+
+        # d_val and dmin_val have shape (n_blocks, 1, 1). We reshape to (n_blocks, 1)
+        # before viewing as uint8 to get a final 2D shape of (n_blocks, 2) for concatenation.
+        d = d_val.reshape(n_blocks, 1).astype(np.float16).view(np.uint8)
+        dmin = dmin_val.reshape(n_blocks, 1).astype(np.float16).view(np.uint8)
+
+        # --- 5. Quantize the actual data ---
+        # Reconstruct effective scales and mins for each sub-block using original d_val and dmin_val
+        d_eff = d_val * ls.astype(np.float32)
+        m_eff = dmin_val * lm.astype(np.float32)
+
+        # q = round((x + m_eff) / d_eff)
+        with np.errstate(divide="ignore"):
+            L = np.where(d_eff == 0, 0, (sub_blocks + m_eff) / d_eff)
+
+        L = np.clip(np_roundf(L), 0, 15).astype(np.uint8)
+
+        # Pack the 4-bit quantized data (L) into the `qs` array
+        L_reshaped = L.reshape((n_blocks, QK_K // 64, 2, 32))
+        L_low = L_reshaped[:, :, 0, :].reshape(n_blocks, -1)
+        L_high = L_reshaped[:, :, 1, :].reshape(n_blocks, -1)
+        qs = L_low | (L_high << 4)
+
+        # --- 6. Assemble and return the final block ---
+        return np.concatenate([d, dmin, scales_packed, qs], axis=1)
 
     @staticmethod
     def get_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
         n_blocks = scales.shape[0]
-        scales = scales.view(np.uint8)
-        ### Unpacking the following: ###
-        #  0 EEAAAAAA
-        #  1 FFBBBBBB
-        #  2 GGCCCCCC
-        #  3 HHDDDDDD
-        #  4 eeaaaaaa
-        #  5 ffbbbbbb
-        #  6 ggcccccc
-        #  7 hhdddddd
-        #  8 eeeeEEEE
-        #  9 ffffFFFF
-        # 10 ggggGGGG
-        # 11 hhhhHHHH
-        scales = scales.reshape((n_blocks, 3, 4))
-        d, m, m_d = np.split(scales, 3, axis=-2)
-
-        sc = np.concatenate([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], axis=-1)
-        min = np.concatenate([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], axis=-1)
-
-        return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
+        s = scales.view(np.uint8).reshape(n_blocks, Q4_K.K_SCALE_SIZE)
+
+        sc = np.zeros((n_blocks, 8), dtype=np.uint8)
+        m = np.zeros((n_blocks, 8), dtype=np.uint8)
+
+        sc[:, 0:4] = s[:, 0:4] & 0x3F
+        m[:, 0:4] = s[:, 4:8] & 0x3F
+
+        sc[:, 4:8] = (s[:, 8:12] & 0x0F) | ((s[:, 0:4] >> 6) << 4)
+        m[:, 4:8] = (s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)
+
+        return sc, m
 
     @classmethod
     def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
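The 12-byte scale/min packing in quantize_blocks is meant to be the exact inverse of the rewritten get_scale_min, so the two can be checked against each other in isolation. A minimal round-trip sketch on random 6-bit inputs, restating both sides of the commit as a standalone snippet:

import numpy as np

rng = np.random.default_rng(0)
ls = rng.integers(0, 64, size=(1, 8), dtype=np.uint8)  # 6-bit sub-block scales
lm = rng.integers(0, 64, size=(1, 8), dtype=np.uint8)  # 6-bit sub-block mins

# Pack as in quantize_blocks above
s = np.zeros((1, 12), dtype=np.uint8)
s[:, 0:4] = (ls[:, 0:4] & 0x3F) | ((ls[:, 4:8] >> 4) << 6)
s[:, 4:8] = (lm[:, 0:4] & 0x3F) | ((lm[:, 4:8] >> 4) << 6)
s[:, 8:12] = (ls[:, 4:8] & 0x0F) | ((lm[:, 4:8] & 0x0F) << 4)

# Unpack as in the rewritten get_scale_min
sc = np.zeros((1, 8), dtype=np.uint8)
m = np.zeros((1, 8), dtype=np.uint8)
sc[:, 0:4] = s[:, 0:4] & 0x3F
m[:, 0:4] = s[:, 4:8] & 0x3F
sc[:, 4:8] = (s[:, 8:12] & 0x0F) | ((s[:, 0:4] >> 6) << 4)
m[:, 4:8] = (s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)

assert np.array_equal(sc, ls) and np.array_equal(m, lm)
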
@@ -513,15 +675,18 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         d = d.view(np.float16).astype(np.float32)
         dmin = dmin.view(np.float16).astype(np.float32)
 
-        sc, m = Q4_K.get_scale_min(scales)
+        sc, m = cls.get_scale_min(scales)
 
-        d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1))
-        dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1))
+        d_eff = (d * sc.astype(np.float32)).reshape((n_blocks, 8, 1))
+        dm_eff = (dmin * m.astype(np.float32)).reshape((n_blocks, 8, 1))
 
-        qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1))
-        qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32)
+        # Unpack 4-bit values and arrange back into sub-blocks
+        qs_reshaped = qs.reshape(n_blocks, QK_K // 64, 32)
+        qs_unpacked = np.empty((n_blocks, 8, 32), dtype=np.float32)
+        qs_unpacked[:, [0, 2, 4, 6], :] = (qs_reshaped & 0x0F)
+        qs_unpacked[:, [1, 3, 5, 7], :] = (qs_reshaped >> 4)
 
-        return (d * qs - dm).reshape((n_blocks, QK_K))
+        return (d_eff * qs_unpacked - dm_eff).reshape((n_blocks, QK_K))
 
 
 class Q5_K(__Quant, qtype=GGMLQuantizationType.Q5_K):
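
The nibble layout produced in step 5 of quantize_blocks — each byte carries one element of an even sub-block in its low nibble and the matching element of the following odd sub-block in its high nibble — has to agree with the rewritten unpacking in dequantize_blocks. A standalone consistency check restating both sides (not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
L = rng.integers(0, 16, size=(1, 8, 32), dtype=np.uint8)  # 4-bit quants, 8 sub-blocks

# Pack as in quantize_blocks: pairs of sub-blocks share a byte
Lr = L.reshape(1, 4, 2, 32)
qs = (Lr[:, :, 0, :] | (Lr[:, :, 1, :] << 4)).reshape(1, 128)

# Unpack as in the rewritten dequantize_blocks
qs_reshaped = qs.reshape(1, 4, 32)
out = np.empty((1, 8, 32), dtype=np.uint8)
out[:, [0, 2, 4, 6], :] = qs_reshaped & 0x0F
out[:, [1, 3, 5, 7], :] = qs_reshaped >> 4

assert np.array_equal(out, L)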

gguf-py/tests/test_quants.py

Lines changed: 9 additions & 0 deletions
@@ -225,6 +225,15 @@ def do_test(libggml_path: Path, quick: bool = False):
         else:
             logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
 
+        if has_quantize and has_dequantize:
+            pyq = gguf.quants.quantize(rc, qtype)
+            pydq = gguf.quants.dequantize(pyq, qtype)
+            py_mse_loss = np.mean((rc - pydq) ** 2)
+            ggq = ggml_quants.quantize(rc, qtype)
+            ggdq = ggml_quants.dequantize(ggq, qtype)
+            gg_mse_loss = np.mean((rc - ggdq) ** 2)
+            logger.info(f"MSE loss for {qtype.name} quant vs dequant: python = {py_mse_loss:.6f}, gg = {gg_mse_loss:.6f}")
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
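
The new test block compares the Python round-trip error against the C reference via the same public helpers it calls. For a quick manual check outside the harness, a round trip through gguf.quants alone is enough; a minimal sketch, assuming this branch of gguf-py is importable (the C-side comparison through ggml_quants needs the compiled libggml and is omitted here):

import numpy as np
import gguf
from gguf.constants import GGMLQuantizationType

# 1024 floats = 4 blocks of QK_K (256 values each)
data = np.random.default_rng(0).normal(size=1024).astype(np.float32)

for qtype in (GGMLQuantizationType.Q3_K, GGMLQuantizationType.Q4_K):
    q = gguf.quants.quantize(data, qtype)   # packed uint8 blocks
    dq = gguf.quants.dequantize(q, qtype)   # back to float32
    mse = float(np.mean((data - dq) ** 2))
    print(f"{qtype.name}: {q.nbytes} bytes packed, round-trip MSE = {mse:.6f}")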
