1010
1111import numpy as np
1212
13+
1314# see ggml_e8m0_to_fp32_half in ggml-impl.h
1415def e8m0_to_fp32_half (x : np .ndarray ) -> np .ndarray :
1516 bits = np .where (x < 2 , np .uint32 (0x00200000 ) << np .uint32 (x ), np .uint32 (x - 1 ) << np .uint32 (23 ))
1617 return bits .view (np .float32 )
1718
19+
1820def quant_shape_to_byte_shape (shape : Sequence [int ], quant_type : GGMLQuantizationType ) -> tuple [int , ...]:
1921 block_size , type_size = GGML_QUANT_SIZES [quant_type ]
2022 if shape [- 1 ] % block_size != 0 :
@@ -662,6 +664,10 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
662664 # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
663665 kvalues = (0 , 1 , 2 , 3 , 4 , 6 , 8 , 12 , 0 , - 1 , - 2 , - 3 , - 4 , - 6 , - 8 , - 12 )
664666
667+ @classmethod
668+ def __e8m0_to_fp32_half (cls , x : np .ndarray ) -> np .ndarray :
669+ e8m0_to_fp32_half (x )
670+
665671 @classmethod
666672 def quantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
667673 n_blocks = blocks .shape [0 ]
@@ -671,7 +677,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
671677 with np .errstate (divide = "ignore" ):
672678 e = np .where (d > 0 , np .floor (np .log2 (d )) - 2 + 127 , 0 ).astype (np .uint8 )
673679
674- d = cls .e8m0_to_fp32_half (e )
680+ d = cls .__e8m0_to_fp32_half (e )
675681
676682 kvalues = np .array (cls .kvalues , dtype = np .int8 ).reshape ((1 , 1 , 16 ))
677683
@@ -691,7 +697,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
691697
692698 e , qs = np .hsplit (blocks , [1 ])
693699
694- d = cls .e8m0_to_fp32_half (e )
700+ d = cls .__e8m0_to_fp32_half (e )
695701
696702 qs = qs .reshape ((n_blocks , 1 , cls .block_size // 2 )) >> np .array ([0 , 4 ], dtype = np .uint8 ).reshape ((1 , 2 , 1 ))
697703 qs = (qs & np .uint8 (0x0F )).view (np .int8 )
@@ -701,6 +707,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
701707
702708 return (d * qs .astype (np .float32 ))
703709
710+
704711class MXFP6E3M2 (__Quant , qtype = GGMLQuantizationType .MXFP6_E3M2 ):
705712 # e3m2 values (origin * 16)
706713 # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
@@ -713,6 +720,10 @@ class MXFP6E3M2(__Quant, qtype=GGMLQuantizationType.MXFP6_E3M2):
713720 - 48 , - 56 , - 64 , - 80 , - 96 , - 112 , - 128 , - 160 , - 192 , - 224 , - 256 , - 320 , - 384 , - 448 ,
714721 )
715722
723+ @classmethod
724+ def __e8m0_to_fp32_half (cls , x : np .ndarray ) -> np .ndarray :
725+ e8m0_to_fp32_half (x )
726+
716727 @classmethod
717728 def quantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
718729 n_blocks = blocks .shape [0 ]
@@ -726,7 +737,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
726737 )
727738
728739 # d is float of above e8m0
729- d = cls .e8m0_to_fp32_half (e )
740+ d = cls .__e8m0_to_fp32_half (e )
730741
731742 kvalues = np .array (cls .kvalues , dtype = np .int16 ).reshape ((1 , 1 , 64 ))
732743
@@ -763,7 +774,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
763774
764775 e , qs = np .hsplit (blocks , [1 ])
765776
766- d = cls .e8m0_to_fp32_half (e ).astype (np .float32 )
777+ d = cls .__e8m0_to_fp32_half (e ).astype (np .float32 )
767778
768779 qs_groups = qs .reshape ((n_blocks , - 1 , 3 ))
769780 b0 = qs_groups [..., 0 ]
@@ -796,6 +807,10 @@ class MXFP6E2M3(__Quant, qtype=GGMLQuantizationType.MXFP6_E2M3):
796807 - 22 , - 24 , - 26 , - 28 , - 30 , - 32 , - 36 , - 40 , - 44 , - 48 , - 52 , - 56 , - 60 ,
797808 )
798809
810+ @classmethod
811+ def __e8m0_to_fp32_half (cls , x : np .ndarray ) -> np .ndarray :
812+ e8m0_to_fp32_half (x )
813+
799814 @classmethod
800815 def quantize_blocks (cls , blocks : np .ndarray ) -> np .ndarray :
801816 n_blocks = blocks .shape [0 ]
@@ -809,7 +824,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
809824 )
810825
811826 # d is float of above e8m0
812- d = cls .e8m0_to_fp32_half (e )
827+ d = cls .__e8m0_to_fp32_half (e )
813828
814829 kvalues = np .array (cls .kvalues , dtype = np .int16 ).reshape ((1 , 1 , 64 ))
815830
@@ -846,7 +861,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
846861
847862 e , qs = np .hsplit (blocks , [1 ])
848863
849- d = cls .e8m0_to_fp32_half (e ).astype (np .float32 )
864+ d = cls .__e8m0_to_fp32_half (e ).astype (np .float32 )
850865
851866 qs_groups = qs .reshape ((n_blocks , - 1 , 3 ))
852867 b0 = qs_groups [..., 0 ]
0 commit comments