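"""Convert the Kyutai Mimi audio codec model (HuggingFace safetensors) to GGUF.

Example invocations below are illustrative: the file name assumes this script
is saved as convert_mimi_to_gguf.py, and the positional model argument
defaults to the "kyutai/mimi" Hugging Face model ID:

    python convert_mimi_to_gguf.py --outtype f16 --outfile kyutai-mimi.gguf
    python convert_mimi_to_gguf.py --outtype q8_0 /path/to/local/mimi
"""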
import gguf
import argparse
import logging
import torch
from typing import Union
from pathlib import Path
from torch import Tensor
from transformers import MimiModel, PreTrainedModel

logger = logging.getLogger("mimi")


class MimiModelConverter:
    mimi_model: PreTrainedModel
    gguf_writer: gguf.GGUFWriter
    fname_out: Path
    ftype: gguf.LlamaFileType

    def __init__(self,
                 pretrained_model_name_or_path: Union[Path, str],
                 fname_out: Path,
                 ftype: gguf.LlamaFileType,
                 is_big_endian: bool):
        self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path)
        self.fname_out = fname_out
        self.ftype = ftype
        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        # the arch string is a deliberate placeholder: this file is not meant
        # to be loaded by the generic llama.cpp model loader
        self.gguf_writer = gguf.GGUFWriter(
            path=None,
            arch="if you see this, you are using the wrong file",
            endianess=endianess)

        assert self.mimi_model.config.architectures[0] == "MimiModel"

        # load tensors
        for name, data_torch in self.mimi_model.state_dict().items():
            # convert any unsupported data types to float32
            old_dtype = data_torch.dtype
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)
            self.add_tensor(name, data_torch, old_dtype)

    def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype):
        # only 2-D+ weight matrices are worth quantizing; 1-D tensors and
        # biases are small and stay in F32
        is_1d = len(data_torch.shape) == 1
        is_bias = ".bias" in name
        can_quantize = not is_1d and not is_bias
        data_qtype = gguf.GGMLQuantizationType.F32

        n_head = self.mimi_model.config.num_attention_heads
        n_kv_head = self.mimi_model.config.num_key_value_heads
        if name.endswith(("q_proj.weight", "q_proj.bias")):
            data_torch = self.undo_permute(data_torch, n_head, n_head)
        if name.endswith(("k_proj.weight", "k_proj.bias")):
            data_torch = self.undo_permute(data_torch, n_head, n_kv_head)

        # materialize the quantizer codebook: HF stores it as running
        # statistics (embed_sum, cluster_usage); the embeddings are their ratio
        if ".codebook.initialized" in name:
            # use the "initialized" flag tensor as the anchor for emitting the
            # computed codebook under the shorter ".codebook" name
            state_dict = self.mimi_model.state_dict()
            embed_sum = state_dict[name.replace(".initialized", ".embed_sum")]
            cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")]
            # see modeling_mimi.py --> MimiEuclideanCodebook
            data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None]
            name = name.replace(".initialized", "")

        # the codebook statistics are folded into the codebook above, skip them
        if ".cluster_usage" in name or ".embed_sum" in name:
            return

        # reshape conv biases to column vectors of shape (n, 1)
        if ".conv.bias" in name:
            data_torch = data_torch.view((1, data_torch.shape[0]))
            data_torch = data_torch.transpose(0, 1)

        # drop the trailing unit dimension of the quantizer projections (3d to 2d)
        if "quantizer" in name and "_proj." in name:
            assert data_torch.shape[2] == 1
            data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1]))

        # shorten the name, otherwise it will be too long for ggml to read
        name = name.replace("_residual_vector_quantizer", "_rvq")

        if can_quantize:
            if self.ftype == gguf.LlamaFileType.ALL_F32:
                data_qtype = gguf.GGMLQuantizationType.F32
            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                data_qtype = gguf.GGMLQuantizationType.F16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                data_qtype = gguf.GGMLQuantizationType.BF16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
                data_qtype = gguf.GGMLQuantizationType.Q8_0
            else:
                raise ValueError(f"Unsupported file type: {self.ftype}")

        # conv kernels are always F16
        if ".conv.weight" in name:
            data_qtype = gguf.GGMLQuantizationType.F16

        data = data_torch.numpy()

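        # gguf.quants.quantize may raise on shapes a given type cannot encode
        # (for example, Q8_0 needs row sizes that are a multiple of its
        # 32-element block); such tensors fall back to F16 below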
        try:
            data = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        # reverse the shape to match ggml's internal dimension order
        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
        logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)

    def write(self):
        self.gguf_writer.write_header_to_file(path=self.fname_out)
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file(progress=True)
        self.gguf_writer.close()

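    # Note on undo_permute: it regroups the rows of each attention head so the
    # two rotary half-blocks become interleaved pairs, the row order expected
    # on the ggml side. Worked example for one head of 8 rows: reshape(1, 2, 4)
    # forms the half-blocks [0 1 2 3] and [4 5 6 7]; swapaxes(1, 2) followed
    # by reshape interleaves them into row order [0 4 1 5 2 6 3 7].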
    @staticmethod
    def undo_permute(weights: Tensor, n_head: int, n_head_kv: int):
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert Mimi safetensors model to GGUF")
    parser.add_argument(
        "--outfile", type=Path, default="kyutai-mimi.gguf",
        help="path to write to",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
        help="output format",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model", type=Path,
        help="local directory containing the model, or a model ID to download from the Hugging Face hub",
        nargs="?",
        default="kyutai/mimi",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: model")
    return args


def main() -> None:
    args = parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    dir_model = args.model

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    }

    logger.info(f"Loading model: {dir_model}")

    with torch.inference_mode():
        converter = MimiModelConverter(
            pretrained_model_name_or_path=dir_model,
            fname_out=args.outfile,
            ftype=ftype_map[args.outtype],
            is_big_endian=args.bigendian,
        )
        converter.write()


if __name__ == '__main__':
    main()
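
# A minimal sketch of how the output can be sanity-checked with the gguf
# Python package (the file name here is just the default --outfile value):
#
#   from gguf import GGUFReader
#   reader = GGUFReader("kyutai-mimi.gguf")
#   for tensor in reader.tensors:
#       print(tensor.name, tensor.shape, tensor.tensor_type.name)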