
Commit c8a5504

reapply changes after sync with main branch

1 parent 9e396b3 commit c8a5504

2 files changed: +224 -2 lines changed


llama_cpp/_internals.py

Lines changed: 215 additions & 1 deletion
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import os
+import sys
+import struct
 import ctypes
 
 from typing import (
@@ -10,6 +12,8 @@
 )
 from dataclasses import dataclass, field
 from contextlib import ExitStack
+from io import BufferedReader
+from enum import IntEnum
 
 import numpy as np
 import numpy.typing as npt
@@ -226,7 +230,7 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
         )
 
     # Extra
-    def metadata(self) -> Dict[str, str]:
+    def _metadata_no_arrays(self) -> Dict[str, str]:
         assert self.model is not None
         metadata: Dict[str, str] = {}
         buffer_size = 1024
@@ -250,6 +254,12 @@ def metadata(self) -> Dict[str, str]:
             metadata[key] = value
         return metadata
 
+    def metadata(self) -> Dict[str, Union[str, int, float, bool, list]]:
+        assert self.model is not None
+        # Uncomment the next line to use the old method
+        #return self._metadata_no_arrays()
+        return QuickGGUFReader.load_metadata(self.path_model)
+
     @staticmethod
     def default_params():
         """Get the default llama_model_params."""
@@ -833,3 +843,207 @@ def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool):
         if apply_grammar and self.grammar is not None:
             ctx_main.grammar_accept_token(self.grammar, id)
         self.prev.append(id)
+
+class QuickGGUFReader:
+    """
+    All logic in this class is based on the GGUF format specification, which
+    can be found here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
+    """
+    # NOTE: Officially, there is no way to determine if a GGUF file is little
+    # or big endian. The format specification directs us to assume that
+    # a file is little endian in all cases unless additional info is
+    # provided.
+    #
+    # In addition to this, GGUF files cannot run on hosts with the
+    # opposite endianness. And, at this point in the code, the model
+    # is already loaded. Therefore, we can assume that the endianness
+    # of the file is the same as the endianness of the host.
+
+    # the GGUF format versions that this class supports
+    SUPPORTED_GGUF_VERSIONS = [2, 3]
+
+    # GGUF only supports execution on little or big endian machines
+    if sys.byteorder not in ['little', 'big']:
+        raise ValueError(
+            "host is not little or big endian - GGUF is unsupported"
+        )
+
+    # Occasionally check to ensure these values are consistent with
+    # the latest values in llama.cpp/gguf-py/gguf/constants.py
+    class GGUFValueType(IntEnum):
+        UINT8 = 0
+        INT8 = 1
+        UINT16 = 2
+        INT16 = 3
+        UINT32 = 4
+        INT32 = 5
+        FLOAT32 = 6
+        BOOL = 7
+        STRING = 8
+        ARRAY = 9
+        UINT64 = 10
+        INT64 = 11
+        FLOAT64 = 12
+
+    # arguments for struct.unpack() based on gguf value type
+    value_packing: dict = {
+        GGUFValueType.UINT8: "=B",
+        GGUFValueType.INT8: "=b",
+        GGUFValueType.UINT16: "=H",
+        GGUFValueType.INT16: "=h",
+        GGUFValueType.UINT32: "=I",
+        GGUFValueType.INT32: "=i",
+        GGUFValueType.FLOAT32: "=f",
+        GGUFValueType.UINT64: "=Q",
+        GGUFValueType.INT64: "=q",
+        GGUFValueType.FLOAT64: "=d",
+        GGUFValueType.BOOL: "?"
+    }
+
+    # length in bytes for each gguf value type
+    value_lengths: dict = {
+        GGUFValueType.UINT8: 1,
+        GGUFValueType.INT8: 1,
+        GGUFValueType.UINT16: 2,
+        GGUFValueType.INT16: 2,
+        GGUFValueType.UINT32: 4,
+        GGUFValueType.INT32: 4,
+        GGUFValueType.FLOAT32: 4,
+        GGUFValueType.UINT64: 8,
+        GGUFValueType.INT64: 8,
+        GGUFValueType.FLOAT64: 8,
+        GGUFValueType.BOOL: 1
+    }
+
+    @staticmethod
+    def unpack(value_type: GGUFValueType, file: BufferedReader):
+        return struct.unpack(
+            QuickGGUFReader.value_packing.get(value_type),
+            file.read(QuickGGUFReader.value_lengths.get(value_type))
+        )[0]
+
+    @staticmethod
+    def get_single(
+        value_type: GGUFValueType,
+        file: BufferedReader
+    ) -> Union[str, int, float, bool]:
+        """Read a single value from an open file"""
+        if value_type == QuickGGUFReader.GGUFValueType.STRING:
+            string_length = QuickGGUFReader.unpack(
+                QuickGGUFReader.GGUFValueType.UINT64,
+                file=file
+            )
+            value = file.read(string_length)
+            # officially, strings that cannot be decoded into utf-8 are invalid
+            value = value.decode("utf-8")
+        else:
+            value = QuickGGUFReader.unpack(value_type, file=file)
+        return value
+
+    @staticmethod
+    def load_metadata(
+        fn: Union[os.PathLike[str], str]
+    ) -> dict[str, Union[str, int, float, bool, list]]:
+        """
+        Given a path to a GGUF file, peek at its header for metadata
+
+        Return a dictionary where all keys are strings, and values can be
+        strings, ints, floats, bools, or lists
+        """
+
+        metadata: dict[str, Union[str, int, float, bool, list]] = {}
+        with open(fn, "rb") as file:
+            magic = file.read(4)
+
+            if magic != b"GGUF":
+                raise ValueError(
+                    "your model file is not a valid GGUF file "
+                    f"(magic number mismatch, got {magic}, "
+                    "expected b'GGUF')"
+                )
+
+            version = QuickGGUFReader.unpack(
+                QuickGGUFReader.GGUFValueType.UINT32,
+                file=file
+            )
+
+            if version not in QuickGGUFReader.SUPPORTED_GGUF_VERSIONS:
+                raise ValueError(
+                    f"your model file reports GGUF version {version}, but "
+                    f"only versions {QuickGGUFReader.SUPPORTED_GGUF_VERSIONS} "
+                    "are supported. re-convert your model or download a newer "
+                    "version"
+                )
+
+            tensor_count = QuickGGUFReader.unpack(
+                QuickGGUFReader.GGUFValueType.UINT64,
+                file=file
+            )
+
+            if version == 3:
+                metadata_kv_count = QuickGGUFReader.unpack(
+                    QuickGGUFReader.GGUFValueType.UINT64,
+                    file=file
+                )
+            elif version == 2:
+                metadata_kv_count = QuickGGUFReader.unpack(
+                    QuickGGUFReader.GGUFValueType.UINT32,
+                    file=file
+                )
+
+            for _ in range(metadata_kv_count):
+                if version == 3:
+                    key_length = QuickGGUFReader.unpack(
+                        QuickGGUFReader.GGUFValueType.UINT64,
+                        file=file
+                    )
+                elif version == 2:
+                    key_length = 0
+                    while key_length == 0:
+                        # read until next key is found
+                        key_length = QuickGGUFReader.unpack(
+                            QuickGGUFReader.GGUFValueType.UINT32,
+                            file=file
+                        )
+                    file.read(4)  # 4 byte offset for GGUFv2
+                key = file.read(key_length)
+                value_type = QuickGGUFReader.GGUFValueType(
+                    QuickGGUFReader.unpack(
+                        QuickGGUFReader.GGUFValueType.UINT32,
+                        file=file
+                    )
+                )
+                if value_type == QuickGGUFReader.GGUFValueType.ARRAY:
+                    array_value_type = QuickGGUFReader.GGUFValueType(
+                        QuickGGUFReader.unpack(
+                            QuickGGUFReader.GGUFValueType.UINT32,
+                            file=file
+                        )
+                    )
+                    # array_length is the number of items in the array
+                    if version == 3:
+                        array_length = QuickGGUFReader.unpack(
+                            QuickGGUFReader.GGUFValueType.UINT64,
+                            file=file
+                        )
+                    elif version == 2:
+                        array_length = QuickGGUFReader.unpack(
+                            QuickGGUFReader.GGUFValueType.UINT32,
+                            file=file
+                        )
+                        file.read(4)  # 4 byte offset for GGUFv2
+                    array = [
+                        QuickGGUFReader.get_single(
+                            array_value_type,
+                            file=file
+                        ) for _ in range(array_length)
+                    ]
+                    metadata[key.decode()] = array
+                else:
+                    value = QuickGGUFReader.get_single(
+                        value_type,
+                        file=file
+                    )
+                    metadata[key.decode()] = value
+
+        return metadata
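
For reference, the new reader can also be exercised on its own to peek at a GGUF file without loading the model. A minimal sketch, assuming a locally available GGUF file; the path and the metadata keys shown are illustrative and not part of this change:

from llama_cpp._internals import QuickGGUFReader

# reads only the GGUF header (magic, version, tensor count, key/value pairs);
# no tensor data is touched
metadata = QuickGGUFReader.load_metadata("models/example-q4_k_m.gguf")  # hypothetical path

# scalar keys come back as str/int/float/bool, array-typed keys as Python lists
print(metadata["general.architecture"])
print(len(metadata.get("tokenizer.ggml.tokens", [])))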

llama_cpp/llama.py

Lines changed: 9 additions & 1 deletion
@@ -428,7 +428,15 @@ def __init__(
                 print(f"Failed to load metadata: {e}", file=sys.stderr)
 
         if self.verbose:
-            print(f"Model metadata: {self.metadata}", file=sys.stderr)
+            print("Model metadata:", file=sys.stderr)
+            for k, v in self.metadata.items():
+                # only calculate repr() once as it may be slow for large arrays
+                repr_v = repr(v)
+                if len(repr_v) > 63:
+                    # truncate long values
+                    print(f" {k}: {repr_v[:60]}...", file=sys.stderr)
+                else:
+                    print(f" {k}: {repr_v}", file=sys.stderr)
 
         eos_token_id = self.token_eos()
         bos_token_id = self.token_bos()
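
The truncation rule above keeps verbose logs readable now that metadata can contain large arrays: any value whose repr() is longer than 63 characters is cut to its first 60 characters plus an ellipsis. A standalone sketch of the same logic; the sample dictionary is made up for illustration:

metadata = {
    "general.name": "example-model",
    "tokenizer.ggml.tokens": [f"token_{i}" for i in range(32000)],  # large array
}

for k, v in metadata.items():
    repr_v = repr(v)  # computed once; repr() of a big list is slow
    if len(repr_v) > 63:
        print(f" {k}: {repr_v[:60]}...")
    else:
        print(f" {k}: {repr_v}")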
