Skip to content

Commit c1e3f10

Browse files
author
katsu560
committed
rename to gguf_add_file.py
1 parent 20c186c commit c1e3f10

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

examples/yolo/gguf_add_file.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#!/usr/bin/env python3
2+
# gguf_add_file.py srcfile dstfile addfiles ...
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
import argparse
8+
import os
9+
import sys
10+
from pathlib import Path
11+
from typing import Any
12+
13+
import numpy as np
14+
import numpy.typing as npt
15+
16+
# Necessary to load the local gguf package
17+
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
18+
sys.path.insert(0, str(Path(__file__).parent.parent))
19+
20+
from gguf import GGUFReader, GGUFWriter, ReaderField, GGMLQuantizationType, GGUFEndian, GGUFValueType, Keys # noqa: E402
21+
22+
logger = logging.getLogger("ggufi_add_file")
23+
24+
25+
def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:
26+
host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'
27+
if reader.byte_order == 'S':
28+
file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE'
29+
else:
30+
file_endian = host_endian
31+
return (host_endian, file_endian)
32+
33+
34+
def get_byteorder(reader: GGUFReader) -> GGUFEndian:
35+
if np.uint32(1) == np.uint32(1).newbyteorder("<"):
36+
# Host is little endian
37+
host_endian = GGUFEndian.LITTLE
38+
swapped_endian = GGUFEndian.BIG
39+
else:
40+
# Sorry PDP or other weird systems that don't use BE or LE.
41+
host_endian = GGUFEndian.BIG
42+
swapped_endian = GGUFEndian.LITTLE
43+
44+
if reader.byte_order == "S":
45+
return swapped_endian
46+
else:
47+
return host_endian
48+
49+
50+
def decode_field(field: ReaderField) -> Any:
51+
if field and field.types:
52+
main_type = field.types[0]
53+
54+
if main_type == GGUFValueType.ARRAY:
55+
sub_type = field.types[-1]
56+
57+
if sub_type == GGUFValueType.STRING:
58+
return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
59+
else:
60+
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
61+
if main_type == GGUFValueType.STRING:
62+
return str(bytes(field.parts[-1]), encoding='utf8')
63+
else:
64+
return field.parts[-1][0]
65+
66+
return None
67+
68+
69+
def get_field_data(reader: GGUFReader, key: str) -> Any:
70+
field = reader.get_field(key)
71+
72+
return decode_field(field)
73+
74+
75+
def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, filename: str[Any]) -> None:
76+
logger.debug(f'copy_with_filename: {filename}') #debug
77+
val = filename
78+
for field in reader.fields.values():
79+
# Suppress virtual fields and fields written by GGUFWriter
80+
if field.name == Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
81+
logger.debug(f'Suppressing {field.name}')
82+
continue
83+
84+
# Copy existed fields except 'embedded_files'
85+
if not field.name == Keys.EMBEDDED_FILES:
86+
cur_val = decode_field(field)
87+
writer.add_key(field.name)
88+
writer.add_val(cur_val, field.types[0])
89+
logger.debug(f'Copying {field.name}')
90+
continue
91+
92+
# Update embedded_files
93+
val = decode_field(field)
94+
for path in filename:
95+
logger.debug(f'Adding {field.name}: {path}')
96+
val.append(path)
97+
98+
# Add filenames to kv
99+
logger.info(f'* Modifying {Keys.EMBEDDED_FILES} to {val}')
100+
writer.add_array(Keys.EMBEDDED_FILES, val)
101+
102+
for tensor in reader.tensors:
103+
# Dimensions are written in reverse order, so flip them first
104+
shape = np.flipud(tensor.shape)
105+
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
106+
107+
# Add file info as tensor_info
108+
for path in filename:
109+
logger.debug(f'Adding tensor_info {path}')
110+
with open(path, "rb") as f:
111+
data = f.read()
112+
data_len = len(data)
113+
dims = [data_len]
114+
raw_dtype = GGMLQuantizationType.I8
115+
writer.add_tensor_info(path, dims, np.float16, data_len, raw_dtype)
116+
117+
writer.write_header_to_file()
118+
writer.write_kv_data_to_file()
119+
writer.write_ti_data_to_file()
120+
121+
for tensor in reader.tensors:
122+
writer.write_tensor_data(tensor.data)
123+
124+
# Write file body as tensor data
125+
for path in filename:
126+
logger.debug(f'Adding tensor data {path}')
127+
with open(path, "rb") as f:
128+
data = f.read()
129+
data_len = len(data)
130+
# write data with padding
131+
writer.write_data(data)
132+
133+
writer.close()
134+
135+
136+
def main() -> None:
137+
parser = argparse.ArgumentParser(description="Add files to GGUF file metadata")
138+
parser.add_argument("input", type=str, help="GGUF format model input filename")
139+
parser.add_argument("output", type=str, help="GGUF format model output filename")
140+
parser.add_argument("addfiles", type=str, nargs='+', help="add filenames ...")
141+
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
142+
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
143+
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
144+
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
145+
146+
logger.info(f'* Loading: {args.input}')
147+
reader = GGUFReader(args.input, 'r')
148+
arch = get_field_data(reader, Keys.General.ARCHITECTURE)
149+
endianess = get_byteorder(reader)
150+
151+
if os.path.isfile(args.output) and not args.force:
152+
logger.warning('*** Warning *** Warning *** Warning **')
153+
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
154+
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
155+
response = input('YES, I am sure> ')
156+
if response != 'YES':
157+
logger.info("You didn't enter YES. Okay then, see ya!")
158+
sys.exit(0)
159+
160+
logger.info(f'* Writing: {args.output}')
161+
writer = GGUFWriter(args.output, arch=arch, endianess=endianess)
162+
163+
alignment = get_field_data(reader, Keys.General.ALIGNMENT)
164+
if alignment is not None:
165+
logger.debug(f'Setting custom alignment: {alignment}')
166+
writer.data_alignment = alignment
167+
168+
if args.addfiles is not None:
169+
filename = []
170+
for path in args.addfiles:
171+
filename.append(path)
172+
logger.info(f'* Adding: {path}')
173+
copy_with_filename(reader, writer, filename)
174+
175+
176+
if __name__ == '__main__':
177+
main()

0 commit comments

Comments
 (0)