Skip to content

Commit e18593c

Browse files
author
katsu560
committed
refactor code, fix copying key value, add --force
1 parent 695fbaf commit e18593c

File tree

1 file changed

+38
-35
lines changed

1 file changed

+38
-35
lines changed

examples/yolo/gguf-addfile.py

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -73,47 +73,41 @@ def get_field_data(reader: GGUFReader, key: str) -> Any:
7373
return decode_field(field)
7474

7575

76-
def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], filename: str[Any]) -> None:
76+
def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, filename: str[Any]) -> None:
77+
logger.debug(f'copy_with_filename: {filename}') #debug
78+
val = filename
7779
for field in reader.fields.values():
7880
# Suppress virtual fields and fields written by GGUFWriter
7981
if field.name == Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
8082
logger.debug(f'Suppressing {field.name}')
8183
continue
8284

83-
# Skip old chat templates if we have new ones
84-
if field.name.startswith(Keys.Tokenizer.CHAT_TEMPLATE) and Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
85-
logger.debug(f'Skipping {field.name}')
86-
continue
87-
88-
old_val = decode_field(field)
89-
val = new_metadata.get(field.name, old_val)
90-
91-
if field.name in new_metadata:
92-
logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
93-
del new_metadata[field.name]
94-
elif val is not None:
95-
logger.debug(f'Copying {field.name}')
96-
97-
if val is not None:
85+
# Copy existed fields except 'embedded_files'
86+
if not field.name == Keys.EMBEDDED_FILES:
87+
cur_val = decode_field(field)
9888
writer.add_key(field.name)
99-
writer.add_val(val, field.types[0])
89+
writer.add_val(cur_val, field.types[0])
90+
logger.debug(f'Copying {field.name}')
91+
continue
10092

101-
if Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
102-
logger.debug('Adding chat template(s)')
103-
writer.add_chat_template(new_metadata[Keys.Tokenizer.CHAT_TEMPLATE])
104-
del new_metadata[Keys.Tokenizer.CHAT_TEMPLATE]
93+
# Update embedded_files
94+
val = decode_field(field)
95+
for path in filename:
96+
logger.debug(f'Adding {field.name}: {path}')
97+
val.append(path)
10598

106-
# add filenames to kv
107-
writer.add_array(Keys.EMBEDDED_FILES, filename)
99+
# Add filenames to kv
100+
logger.info(f'* Modifying {Keys.EMBEDDED_FILES} to {val}')
101+
writer.add_array(Keys.EMBEDDED_FILES, val)
108102

109103
for tensor in reader.tensors:
110104
# Dimensions are written in reverse order, so flip them first
111105
shape = np.flipud(tensor.shape)
112106
writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
113107

114-
# add file info as tensor_info
108+
# Add file info as tensor_info
115109
for path in filename:
116-
logger.debug(f'Adding {path}')
110+
logger.debug(f'Adding tensor_info {path}')
117111
with open(path, "rb") as f:
118112
data = f.read()
119113
data_len = len(data)
@@ -128,9 +122,9 @@ def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_met
128122
for tensor in reader.tensors:
129123
writer.write_tensor_data(tensor.data)
130124

131-
# write file body as tensor data
125+
# Write file body as tensor data
132126
for path in filename:
133-
logger.debug(f'Adding {path}')
127+
logger.debug(f'Adding tensor data {path}')
134128
with open(path, "rb") as f:
135129
data = f.read()
136130
data_len = len(data)
@@ -145,6 +139,7 @@ def main() -> None:
145139
parser.add_argument("input", type=str, help="GGUF format model input filename")
146140
parser.add_argument("output", type=str, help="GGUF format model output filename")
147141
parser.add_argument("addfiles", type=str, nargs='+', help="add filenames ...")
142+
parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
148143
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
149144
args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])
150145
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -154,6 +149,15 @@ def main() -> None:
154149
arch = get_field_data(reader, Keys.General.ARCHITECTURE)
155150
endianess = get_byteorder(reader)
156151

152+
if os.path.isfile(args.output) and not args.force:
153+
logger.warning('*** Warning *** Warning *** Warning **')
154+
logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
155+
logger.warning('* Enter exactly YES if you are positive you want to proceed:')
156+
response = input('YES, I am sure> ')
157+
if response != 'YES':
158+
logger.info("You didn't enter YES. Okay then, see ya!")
159+
sys.exit(0)
160+
157161
logger.info(f'* Writing: {args.output}')
158162
writer = GGUFWriter(args.output, arch=arch, endianess=endianess)
159163

@@ -162,14 +166,13 @@ def main() -> None:
162166
logger.debug(f'Setting custom alignment: {alignment}')
163167
writer.data_alignment = alignment
164168

165-
logger.info(f'* Adding: {args.addfiles}')
166-
new_metadata = {}
167-
filename = []
168-
for path in args.addfiles:
169-
filename.append(path)
170-
logger.info(f'* Adding: {path}')
171-
copy_with_filename(reader, writer, new_metadata, filename)
172-
169+
if args.addfiles is not None:
170+
filename = []
171+
for path in args.addfiles:
172+
filename.append(path)
173+
logger.info(f'* Adding: {path}')
174+
copy_with_filename(reader, writer, filename)
175+
173176

174177
if __name__ == '__main__':
175178
main()

0 commit comments

Comments
 (0)