Skip to content

Commit b175e6a

Browse files
author
horasal
committed
solve flake8 and pyright error
1 parent eb0f6c0 commit b175e6a

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

gguf-py/gguf/quants.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
import numpy as np
1212

13+
1314
# see ggml_e8m0_to_fp32_half in ggml-impl.h
1415
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
    """Decode E8M0 exponent codes into float32 values of ``2 ** (x - 128)``.

    The float32 bit pattern is assembled directly as ``uint32`` and then
    reinterpreted, so no floating-point arithmetic is involved.

    Args:
        x: Array of unsigned exponent codes (callers pass ``uint8`` data).

    Returns:
        Array with the shape of ``x`` and dtype ``float32``.
    """
    # Codes 0 and 1 map below the smallest normal float32, so they are encoded
    # as subnormals: a single mantissa bit shifted into place.
    needs_subnormal = x < 2
    subnormal_bits = np.uint32(0x00200000) << np.uint32(x)
    # Every other code becomes a normal float with biased exponent (x - 1) and
    # a zero mantissa; the exponent field starts at bit 23.
    normal_bits = np.uint32(x - 1) << np.uint32(23)
    packed = np.where(needs_subnormal, subnormal_bits, normal_bits)
    return packed.view(np.float32)
1718

19+
1820
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
1921
block_size, type_size = GGML_QUANT_SIZES[quant_type]
2022
if shape[-1] % block_size != 0:
@@ -662,6 +664,10 @@ class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
662664
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
663665
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
664666

667+
@classmethod
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
    """Class-private wrapper around the module-level ``e8m0_to_fp32_half``.

    The double-underscore name is mangled by Python, so this is only
    reachable as ``cls.__e8m0_to_fp32_half`` from inside this class body.

    Args:
        x: Array of E8M0 exponent codes.

    Returns:
        Whatever ``e8m0_to_fp32_half(x)`` returns (annotated as ``np.ndarray``).
    """
    # The result must be returned: callers assign it to ``d`` and use it as
    # the per-block scale. Without ``return`` this method yields ``None``.
    return e8m0_to_fp32_half(x)
670+
665671
@classmethod
666672
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
667673
n_blocks = blocks.shape[0]
@@ -671,7 +677,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
671677
with np.errstate(divide="ignore"):
672678
e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
673679

674-
d = cls.e8m0_to_fp32_half(e)
680+
d = cls.__e8m0_to_fp32_half(e)
675681

676682
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
677683

@@ -691,7 +697,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
691697

692698
e, qs = np.hsplit(blocks, [1])
693699

694-
d = cls.e8m0_to_fp32_half(e)
700+
d = cls.__e8m0_to_fp32_half(e)
695701

696702
qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
697703
qs = (qs & np.uint8(0x0F)).view(np.int8)
@@ -701,6 +707,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
701707

702708
return (d * qs.astype(np.float32))
703709

710+
704711
class MXFP6E3M2(__Quant, qtype=GGMLQuantizationType.MXFP6_E3M2):
705712
# e3m2 values (origin * 16)
706713
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
@@ -713,6 +720,10 @@ class MXFP6E3M2(__Quant, qtype=GGMLQuantizationType.MXFP6_E3M2):
713720
-48, -56, -64, -80, -96, -112, -128, -160, -192, -224, -256, -320, -384, -448,
714721
)
715722

723+
@classmethod
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
    """Class-private wrapper around the module-level ``e8m0_to_fp32_half``.

    The double-underscore name is mangled by Python, so this is only
    reachable as ``cls.__e8m0_to_fp32_half`` from inside this class body.

    Args:
        x: Array of E8M0 exponent codes.

    Returns:
        Whatever ``e8m0_to_fp32_half(x)`` returns (annotated as ``np.ndarray``).
    """
    # The result must be returned: callers chain ``.astype(np.float32)`` on it
    # and use it as the per-block scale. Without ``return`` this yields ``None``.
    return e8m0_to_fp32_half(x)
726+
716727
@classmethod
717728
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
718729
n_blocks = blocks.shape[0]
@@ -726,7 +737,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
726737
)
727738

728739
# d is float of above e8m0
729-
d = cls.e8m0_to_fp32_half(e)
740+
d = cls.__e8m0_to_fp32_half(e)
730741

731742
kvalues = np.array(cls.kvalues, dtype=np.int16).reshape((1, 1, 64))
732743

@@ -763,7 +774,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
763774

764775
e, qs = np.hsplit(blocks, [1])
765776

766-
d = cls.e8m0_to_fp32_half(e).astype(np.float32)
777+
d = cls.__e8m0_to_fp32_half(e).astype(np.float32)
767778

768779
qs_groups = qs.reshape((n_blocks, -1, 3))
769780
b0 = qs_groups[..., 0]
@@ -796,6 +807,10 @@ class MXFP6E2M3(__Quant, qtype=GGMLQuantizationType.MXFP6_E2M3):
796807
-22, -24, -26, -28, -30, -32, -36, -40, -44, -48, -52, -56, -60,
797808
)
798809

810+
@classmethod
def __e8m0_to_fp32_half(cls, x: np.ndarray) -> np.ndarray:
    """Class-private wrapper around the module-level ``e8m0_to_fp32_half``.

    The double-underscore name is mangled by Python, so this is only
    reachable as ``cls.__e8m0_to_fp32_half`` from inside this class body.

    Args:
        x: Array of E8M0 exponent codes.

    Returns:
        Whatever ``e8m0_to_fp32_half(x)`` returns (annotated as ``np.ndarray``).
    """
    # The result must be returned: callers chain ``.astype(np.float32)`` on it
    # and use it as the per-block scale. Without ``return`` this yields ``None``.
    return e8m0_to_fp32_half(x)
813+
799814
@classmethod
800815
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
801816
n_blocks = blocks.shape[0]
@@ -809,7 +824,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
809824
)
810825

811826
# d is float of above e8m0
812-
d = cls.e8m0_to_fp32_half(e)
827+
d = cls.__e8m0_to_fp32_half(e)
813828

814829
kvalues = np.array(cls.kvalues, dtype=np.int16).reshape((1, 1, 64))
815830

@@ -846,7 +861,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
846861

847862
e, qs = np.hsplit(blocks, [1])
848863

849-
d = cls.e8m0_to_fp32_half(e).astype(np.float32)
864+
d = cls.__e8m0_to_fp32_half(e).astype(np.float32)
850865

851866
qs_groups = qs.reshape((n_blocks, -1, 3))
852867
b0 = qs_groups[..., 0]

0 commit comments

Comments
 (0)