@@ -15,6 +15,7 @@ def e8m0_to_fp32_any(x: np.ndarray, e: np.uint32) -> np.ndarray:
1515 bits = np .where (x < e + 1 , np .uint32 (1 ) << np .uint32 (x - e + 22 ), np .uint32 (x - e ) << np .uint32 (23 ))
1616 return bits .view (np .float32 )
1717
18+
1819def quant_shape_to_byte_shape (shape : Sequence [int ], quant_type : GGMLQuantizationType ) -> tuple [int , ...]:
1920 block_size , type_size = GGML_QUANT_SIZES [quant_type ]
2021 if shape [- 1 ] % block_size != 0 :
@@ -730,9 +731,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
730731
731732 with np .errstate (divide = "ignore" ):
732733 # convert log2(d_max) to e8m0
733- e = np .where (d_max > 0 , np .floor (np .log2 (d_max )) - 9 + 127 , 0 ).astype (
734- np .uint8
735- )
734+ e = np .where (d_max > 0 , np .floor (np .log2 (d_max )) - 3 + 127 , 0 ).astype (np .uint8 )
736735
737736 # d is float of above e8m0
738737 d = cls .__e8m0_to_fp32_scaled (e )
@@ -817,9 +816,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
817816
818817 with np .errstate (divide = "ignore" ):
819818 # convert log2(d_max) to e8m0
820- e = np .where (d_max > 0 , np .floor (np .log2 (d_max )) - 6 + 127 , 0 ).astype (
821- np .uint8
822- )
819+ e = np .where (d_max > 0 , np .floor (np .log2 (d_max )) - 1 + 127 , 0 ).astype (np .uint8 )
823820
824821 # d is float of above e8m0
825822 d = cls .__e8m0_to_fp32_scaled (e )
0 commit comments