@@ -228,8 +228,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
228
228
d = max / - 8
229
229
with np .errstate (divide = "ignore" ):
230
230
id = np .where (d == 0 , 0 , 1 / d )
231
- # FIXME: Q4_0's reference rounding is cursed and depends on FMA
232
- qs = np .trunc ((np .float64 (blocks ) * np .float64 (id )) + np .float64 (8.5 ), dtype = np .float32 ).astype (np .uint8 ).clip (0 , 15 )
231
+ qs = np .trunc ((blocks * id ) + np .float32 (8.5 ), dtype = np .float32 ).astype (np .uint8 ).clip (0 , 15 )
233
232
234
233
qs = qs .reshape ((n_blocks , 2 , cls .block_size // 2 ))
235
234
qs = qs [..., 0 , :] | (qs [..., 1 , :] << np .uint8 (4 ))
@@ -300,8 +299,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
300
299
d = max / - 16
301
300
with np .errstate (divide = "ignore" ):
302
301
id = np .where (d == 0 , 0 , 1 / d )
303
- # FIXME: Q5_0's reference rounding is cursed and depends on FMA
304
- q = np .trunc ((np .float64 (blocks ) * np .float64 (id )) + np .float64 (16.5 ), dtype = np .float32 ).astype (np .uint8 ).clip (0 , 31 )
302
+ q = np .trunc ((blocks * id ) + np .float32 (16.5 ), dtype = np .float32 ).astype (np .uint8 ).clip (0 , 31 )
305
303
306
304
qs = q .reshape ((n_blocks , 2 , cls .block_size // 2 ))
307
305
qs = (qs [..., 0 , :] & np .uint8 (0x0F )) | (qs [..., 1 , :] << np .uint8 (4 ))
@@ -655,6 +653,57 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
655
653
return (d * qs .astype (np .float32 ))
656
654
657
655
656
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
    # Doubled e2m1 (FP4) code points: indices 0-7 are the non-negative values,
    # indices 8-15 their negated counterparts (index 8 is a redundant zero).
    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)

    @staticmethod
    # see ggml_e8m0_to_fp32_half in ggml-impl.h
    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
        """Decode e8m0 exponent bytes into halved float32 scales.

        Exponents below 2 map to subnormal-range bit patterns; all others
        become a plain float32 exponent field (biased down by one, which
        halves the scale to compensate for the doubled kvalues table).
        """
        subnormal = np.uint32(0x00200000) << np.uint32(x)
        normal = np.uint32(x - 1) << np.uint32(23)
        return np.where(x < 2, subnormal, normal).view(np.float32)

    @classmethod
    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Quantize float32 blocks to MXFP4: one e8m0 scale byte + packed 4-bit codes."""
        n_blocks = blocks.shape[0]

        # Shared per-block exponent, derived from the largest magnitude.
        amax = abs(blocks).max(axis=-1, keepdims=True)
        with np.errstate(divide="ignore"):  # log2(0) -> -inf is handled by the where()
            e = np.where(amax > 0, np.floor(np.log2(amax)) - 2 + 127, 0).astype(np.uint8)
        scale = cls.e8m0_to_fp32_half(e)

        # Nearest e2m1 code point by exhaustive search over all 16 candidates.
        codes = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
        dist = np.abs(scale.reshape((n_blocks, 1, 1)) * codes.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
        idx = np.argmin(dist, axis=-1, keepdims=True).astype(np.uint8)

        # Pack two 4-bit indices per byte: first half of the block in the low
        # nibbles, second half in the high nibbles.
        halves = idx.reshape(n_blocks, 2, cls.block_size // 2)
        packed = (halves[:, 0] | (halves[:, 1] << np.uint8(4))).reshape((n_blocks, cls.block_size // 2))

        return np.concatenate([e, packed], axis=-1)

    @classmethod
    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
        """Dequantize MXFP4 byte blocks (scale byte + packed codes) to float32."""
        n_blocks = blocks.shape[0]

        # Leading byte of each block is the e8m0 exponent; the rest are codes.
        e, packed = np.hsplit(blocks, [1])
        scale = cls.e8m0_to_fp32_half(e)

        # Unpack nibbles: axis-1 row 0 holds the low nibbles, row 1 the high
        # nibbles, matching the pack order used by quantize_blocks.
        shifts = np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
        nibbles = (packed.reshape((n_blocks, 1, cls.block_size // 2)) >> shifts) & np.uint8(0x0F)
        idx = nibbles.view(np.int8)

        codes = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
        vals = np.take_along_axis(codes, idx, axis=-1).reshape((n_blocks, cls.block_size))

        return (scale * vals.astype(np.float32))
658
707
class IQ2_XXS (__Quant , qtype = GGMLQuantizationType .IQ2_XXS ):
659
708
ksigns : bytes = (
660
709
b"\x00 \x81 \x82 \x03 \x84 \x05 \x06 \x87 \x88 \x09 \x0a \x8b \x0c \x8d \x8e \x0f "
0 commit comments