@@ -228,8 +228,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         d = max / -8
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q4_0's reference rounding is cursed and depends on FMA
-        qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+        qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
 
         qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
         qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         d = max / -16
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q5_0's reference rounding is cursed and depends on FMA
-        q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+        q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
 
         qs = q.reshape((n_blocks, 2, cls.block_size // 2))
         qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
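
Both hunks above drop the float64/FMA workaround and perform the offset-and-truncate rounding directly in float32. Below is a standalone toy sketch of the index mapping the new Q4_0 line implements (Q5_0 is analogous with -16, 16.5 and a 0..31 clip); the sample values and the reconstruction line are illustrative only, not part of the patch:

    import numpy as np

    # with d = max / -8 and id = 1 / d, trunc(x * id + 8.5) rounds x / d to the
    # nearest integer and shifts it into the unsigned 0..15 range
    x = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float32)
    max_ = x[np.argmax(np.abs(x))]            # signed element with the largest magnitude
    d = max_ / np.float32(-8)
    id_ = np.float32(0) if d == 0 else np.float32(1) / d
    q = np.trunc(x * id_ + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
    print(q)                                  # [ 0  4  8 12 15]
    print(d * (q.astype(np.float32) - 8))     # coarse reconstruction of x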
@@ -655,6 +653,57 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
         return (d * qs.astype(np.float32))
 
 
+class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
+    # e2m1 values (doubled)
+    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
+
+    @staticmethod
+    # see ggml_e8m0_to_fp32_half in ggml-impl.h
+    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
+        bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
+        return bits.view(np.float32)
+
+    @classmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        d = abs(blocks).max(axis=-1, keepdims=True)
+
+        with np.errstate(divide="ignore"):
+            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
+
+        errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
+        best = np.argmin(errs, axis=-1, keepdims=True)
+
+        qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
+        qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
+
+        qs = qs.reshape((n_blocks, cls.block_size // 2))
+
+        return np.concatenate([e, qs], axis=-1)
+
+    @classmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        e, qs = np.hsplit(blocks, [1])
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
+        qs = (qs & np.uint8(0x0F)).view(np.int8)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
+        qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
+
+        return (d * qs.astype(np.float32))
+
+
 class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
     ksigns: bytes = (
         b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
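
For context on the new class: each MXFP4 block is stored as one E8M0 scale byte followed by block_size // 2 bytes of nibble-packed FP4 indices into kvalues, and e8m0_to_fp32_half decodes the scale byte e as 2**(e - 127) / 2 (the kvalues table holds doubled E2M1 values to compensate). A minimal round-trip sketch, assuming this commit is applied and the gguf-py package is importable; the toy data is illustrative:

    import numpy as np
    from gguf.quants import MXFP4   # assumes this patch is merged into gguf-py

    rng = np.random.default_rng(0)
    blocks = rng.normal(size=(4, MXFP4.block_size)).astype(np.float32)

    packed = MXFP4.quantize_blocks(blocks)      # uint8, shape (4, 1 + MXFP4.block_size // 2)
    restored = MXFP4.dequantize_blocks(packed)  # float32, same shape as blocks

    # scale byte 127 decodes to 2**0 / 2 = 0.5, 128 to 1.0, and so on
    half = MXFP4.e8m0_to_fp32_half(np.array([127, 128], dtype=np.uint8))
    print(half)                                 # [0.5 1. ]

    # E2M1 keeps a single mantissa bit, so expect a coarse reconstruction
    print(np.max(np.abs(restored - blocks)))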