| 
 | 1 | +#!/usr/bin/env python3  | 
 | 2 | +# gguf_add_file.py srcfile dstfile addfiles ...  | 
 | 3 | + | 
 | 4 | +from __future__ import annotations  | 
 | 5 | + | 
 | 6 | +import logging  | 
 | 7 | +import argparse  | 
 | 8 | +import os  | 
 | 9 | +import sys  | 
 | 10 | +from pathlib import Path  | 
 | 11 | +from typing import Any  | 
 | 12 | + | 
 | 13 | +import numpy as np  | 
 | 14 | +import numpy.typing as npt  | 
 | 15 | + | 
 | 16 | +# Necessary to load the local gguf package  | 
 | 17 | +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():  | 
 | 18 | +    sys.path.insert(0, str(Path(__file__).parent.parent))  | 
 | 19 | + | 
 | 20 | +from gguf import GGUFReader, GGUFWriter, ReaderField, GGMLQuantizationType, GGUFEndian, GGUFValueType, Keys  # noqa: E402  | 
 | 21 | + | 
 | 22 | +logger = logging.getLogger("ggufi_add_file")  | 
 | 23 | + | 
 | 24 | + | 
 | 25 | +def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]:  | 
 | 26 | +    host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG'  | 
 | 27 | +    if reader.byte_order == 'S':  | 
 | 28 | +        file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE'  | 
 | 29 | +    else:  | 
 | 30 | +        file_endian = host_endian  | 
 | 31 | +    return (host_endian, file_endian)  | 
 | 32 | + | 
 | 33 | + | 
 | 34 | +def get_byteorder(reader: GGUFReader) -> GGUFEndian:  | 
 | 35 | +    if np.uint32(1) == np.uint32(1).newbyteorder("<"):  | 
 | 36 | +        # Host is little endian  | 
 | 37 | +        host_endian = GGUFEndian.LITTLE  | 
 | 38 | +        swapped_endian = GGUFEndian.BIG  | 
 | 39 | +    else:  | 
 | 40 | +        # Sorry PDP or other weird systems that don't use BE or LE.  | 
 | 41 | +        host_endian = GGUFEndian.BIG  | 
 | 42 | +        swapped_endian = GGUFEndian.LITTLE  | 
 | 43 | + | 
 | 44 | +    if reader.byte_order == "S":  | 
 | 45 | +        return swapped_endian  | 
 | 46 | +    else:  | 
 | 47 | +        return host_endian  | 
 | 48 | + | 
 | 49 | + | 
 | 50 | +def decode_field(field: ReaderField) -> Any:  | 
 | 51 | +    if field and field.types:  | 
 | 52 | +        main_type = field.types[0]  | 
 | 53 | + | 
 | 54 | +        if main_type == GGUFValueType.ARRAY:  | 
 | 55 | +            sub_type = field.types[-1]  | 
 | 56 | + | 
 | 57 | +            if sub_type == GGUFValueType.STRING:  | 
 | 58 | +                return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]  | 
 | 59 | +            else:  | 
 | 60 | +                return [pv for idx in field.data for pv in field.parts[idx].tolist()]  | 
 | 61 | +        if main_type == GGUFValueType.STRING:  | 
 | 62 | +            return str(bytes(field.parts[-1]), encoding='utf8')  | 
 | 63 | +        else:  | 
 | 64 | +            return field.parts[-1][0]  | 
 | 65 | + | 
 | 66 | +    return None  | 
 | 67 | + | 
 | 68 | + | 
 | 69 | +def get_field_data(reader: GGUFReader, key: str) -> Any:  | 
 | 70 | +    field = reader.get_field(key)  | 
 | 71 | + | 
 | 72 | +    return decode_field(field)  | 
 | 73 | + | 
 | 74 | + | 
 | 75 | +def copy_with_filename(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, filename: str[Any]) -> None:  | 
 | 76 | +    logger.debug(f'copy_with_filename: {filename}') #debug  | 
 | 77 | +    val = filename  | 
 | 78 | +    for field in reader.fields.values():  | 
 | 79 | +        # Suppress virtual fields and fields written by GGUFWriter  | 
 | 80 | +        if field.name == Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):  | 
 | 81 | +            logger.debug(f'Suppressing {field.name}')  | 
 | 82 | +            continue  | 
 | 83 | + | 
 | 84 | +        # Copy existed fields except 'embedded_files'  | 
 | 85 | +        if not field.name == Keys.EMBEDDED_FILES:  | 
 | 86 | +            cur_val = decode_field(field)  | 
 | 87 | +            writer.add_key(field.name)  | 
 | 88 | +            writer.add_val(cur_val, field.types[0])  | 
 | 89 | +            logger.debug(f'Copying {field.name}')  | 
 | 90 | +            continue  | 
 | 91 | + | 
 | 92 | +        # Update embedded_files  | 
 | 93 | +        val = decode_field(field)  | 
 | 94 | +        for path in filename:  | 
 | 95 | +            logger.debug(f'Adding {field.name}: {path}')  | 
 | 96 | +            val.append(path)  | 
 | 97 | + | 
 | 98 | +    # Add filenames to kv  | 
 | 99 | +    logger.info(f'* Modifying {Keys.EMBEDDED_FILES} to {val}')  | 
 | 100 | +    writer.add_array(Keys.EMBEDDED_FILES, val)  | 
 | 101 | +      | 
 | 102 | +    for tensor in reader.tensors:  | 
 | 103 | +        # Dimensions are written in reverse order, so flip them first  | 
 | 104 | +        shape = np.flipud(tensor.shape)  | 
 | 105 | +        writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)  | 
 | 106 | + | 
 | 107 | +    # Add file info as tensor_info  | 
 | 108 | +    for path in filename:  | 
 | 109 | +        logger.debug(f'Adding tensor_info {path}')  | 
 | 110 | +        with open(path, "rb") as f:  | 
 | 111 | +            data = f.read()  | 
 | 112 | +            data_len = len(data)  | 
 | 113 | +            dims = [data_len]  | 
 | 114 | +            raw_dtype = GGMLQuantizationType.I8  | 
 | 115 | +            writer.add_tensor_info(path, dims, np.float16, data_len, raw_dtype)  | 
 | 116 | + | 
 | 117 | +    writer.write_header_to_file()  | 
 | 118 | +    writer.write_kv_data_to_file()  | 
 | 119 | +    writer.write_ti_data_to_file()  | 
 | 120 | + | 
 | 121 | +    for tensor in reader.tensors:  | 
 | 122 | +        writer.write_tensor_data(tensor.data)  | 
 | 123 | + | 
 | 124 | +    # Write file body as tensor data  | 
 | 125 | +    for path in filename:  | 
 | 126 | +        logger.debug(f'Adding tensor data {path}')  | 
 | 127 | +        with open(path, "rb") as f:  | 
 | 128 | +            data = f.read()  | 
 | 129 | +            data_len = len(data)  | 
 | 130 | +            # write data with padding  | 
 | 131 | +            writer.write_data(data)  | 
 | 132 | + | 
 | 133 | +    writer.close()  | 
 | 134 | + | 
 | 135 | + | 
 | 136 | +def main() -> None:  | 
 | 137 | +    parser = argparse.ArgumentParser(description="Add files to GGUF file metadata")  | 
 | 138 | +    parser.add_argument("input",        type=str,            help="GGUF format model input filename")  | 
 | 139 | +    parser.add_argument("output",       type=str,            help="GGUF format model output filename")  | 
 | 140 | +    parser.add_argument("addfiles",     type=str, nargs='+', help="add filenames ...")  | 
 | 141 | +    parser.add_argument("--force",      action="store_true", help="Bypass warnings without confirmation")  | 
 | 142 | +    parser.add_argument("--verbose",    action="store_true", help="Increase output verbosity")  | 
 | 143 | +    args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"])  | 
 | 144 | +    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)  | 
 | 145 | + | 
 | 146 | +    logger.info(f'* Loading: {args.input}')  | 
 | 147 | +    reader = GGUFReader(args.input, 'r')  | 
 | 148 | +    arch = get_field_data(reader, Keys.General.ARCHITECTURE)  | 
 | 149 | +    endianess = get_byteorder(reader)  | 
 | 150 | + | 
 | 151 | +    if os.path.isfile(args.output) and not args.force:  | 
 | 152 | +        logger.warning('*** Warning *** Warning *** Warning **')  | 
 | 153 | +        logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')  | 
 | 154 | +        logger.warning('* Enter exactly YES if you are positive you want to proceed:')  | 
 | 155 | +        response = input('YES, I am sure> ')  | 
 | 156 | +        if response != 'YES':  | 
 | 157 | +            logger.info("You didn't enter YES. Okay then, see ya!")  | 
 | 158 | +            sys.exit(0)  | 
 | 159 | + | 
 | 160 | +    logger.info(f'* Writing: {args.output}')  | 
 | 161 | +    writer = GGUFWriter(args.output, arch=arch, endianess=endianess)  | 
 | 162 | + | 
 | 163 | +    alignment = get_field_data(reader, Keys.General.ALIGNMENT)  | 
 | 164 | +    if alignment is not None:  | 
 | 165 | +        logger.debug(f'Setting custom alignment: {alignment}')  | 
 | 166 | +        writer.data_alignment = alignment  | 
 | 167 | + | 
 | 168 | +    if args.addfiles is not None:  | 
 | 169 | +        filename = []  | 
 | 170 | +        for path in args.addfiles:  | 
 | 171 | +            filename.append(path)  | 
 | 172 | +            logger.info(f'* Adding: {path}')  | 
 | 173 | +        copy_with_filename(reader, writer, filename)  | 
 | 174 | +   | 
 | 175 | + | 
 | 176 | +if __name__ == '__main__':  | 
 | 177 | +    main()  | 
0 commit comments