Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/transformers/integrations/ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
GGML_TYPES = {
"F32": 0,
"Q4_0": 2,
"Q4_1": 3,
"Q8_0": 8,
"Q2_K": 10,
"Q3_K": 11,
Expand All @@ -52,6 +53,7 @@
"Q4_K": 144,
# Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales
"Q4_0": 2 + 16,
"Q4_1": 2 + 2 + 16,
"Q6_K": 210,
# See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
Expand Down Expand Up @@ -273,6 +275,36 @@ def dequantize_q4_0(data):
return (scales * quants).astype(np.float32)


def dequantize_q4_1(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1106
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L18
block_size = GGML_BLOCK_SIZES["Q4_1"]
num_blocks = len(data) // block_size

data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)

# The scales are stored on the first 2 bytes
scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
# scales = np.nan_to_num(scales)

# The mins are stored on the second 2 bytes
mins = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)

# the rest of the bytes corresponds to the quants - we discard the first four bytes
quants = data_u8[:, 4:]

ql = (quants[:, :] & 0xF).astype(np.int8)
qr = (quants[:, :] >> 4).astype(np.int8)

# Use hstack
quants = np.hstack([ql, qr])

return ((scales * quants) + mins).astype(np.float32)


def dequantize_q6_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
Expand Down Expand Up @@ -493,6 +525,8 @@ def load_dequant_gguf_tensor(shape, ggml_type, data):
values = dequantize_q8_0(data)
elif ggml_type == GGML_TYPES["Q4_0"]:
values = dequantize_q4_0(data)
elif ggml_type == GGML_TYPES["Q4_1"]:
values = dequantize_q4_1(data)
elif ggml_type == GGML_TYPES["Q4_K"]:
values = dequantize_q4_k(data)
elif ggml_type == GGML_TYPES["Q6_K"]:
Expand Down