|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import os |
| 4 | +from typing import Any, Dict, List, Optional, Tuple |
| 5 | + |
| 6 | +import bitsandbytes as bnb |
| 7 | +import click |
| 8 | +import torch |
| 9 | +from peft.tuners.lora import QuantLinear |
| 10 | +from safetensors.torch import save_file |
| 11 | +from tqdm import tqdm |
| 12 | +from transformers import AutoConfig, AutoModelForCausalLM |
| 13 | +from transformers.modeling_utils import PreTrainedModel |
| 14 | + |
| 15 | +from mergekit.card import generate_card_lora |
| 16 | +from mergekit.common import ModelReference |
| 17 | +from mergekit.io import LazyTensorLoader |
| 18 | +from mergekit.options import add_merge_options |
| 19 | + |
| 20 | + |
| 21 | +def _low_rank_decomposition( |
| 22 | + weight: torch.Tensor, reduced_rank: int = 16 |
| 23 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 24 | + """ |
| 25 | + Decompose a 2D matrix into low-rank matrices A and B using SVD.a |
| 26 | +
|
| 27 | + :param weight: The matrix to decompose, of shape (H, W) |
| 28 | + :param reduced_rank: The final rank of the decomposition |
| 29 | + :return: A tuple of tensors (A, B) |
| 30 | + """ |
| 31 | + if weight.dim() != 2: |
| 32 | + raise ValueError( |
| 33 | + f"Only support 2D matrix, but your input has {weight.dim()} dimensions." |
| 34 | + ) |
| 35 | + |
| 36 | + # SVD Decomposition |
| 37 | + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) |
| 38 | + |
| 39 | + # Truncated matrices |
| 40 | + A = Vh[:reduced_rank, :] |
| 41 | + B = U[:, :reduced_rank] @ torch.diag(S[:reduced_rank]) |
| 42 | + |
| 43 | + return A, B |
| 44 | + |
| 45 | + |
| 46 | +def decompose_delta_weight( |
| 47 | + new_weight: torch.Tensor, |
| 48 | + base_weight: torch.Tensor, |
| 49 | + reduced_rank: int, |
| 50 | + device: Optional[str] = None, |
| 51 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 52 | + if device is None: |
| 53 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 54 | + |
| 55 | + new_weight = new_weight.to(device) |
| 56 | + base_weight = base_weight.to(device) |
| 57 | + |
| 58 | + """ |
| 59 | + Decompose the delta weight into low-rank matrices A and B. |
| 60 | +
|
| 61 | + :param new_weight: The updated weight matrix after applying LoRA. |
| 62 | + :param base_weight: The original weight matrix before LoRA. |
| 63 | + :param reduced_rank: The rank for the low-rank decomposition. |
| 64 | + :param device: The device to perform computation on. |
| 65 | + :return: A tuple of tensors (A, B) |
| 66 | + """ |
| 67 | + delta_weight = new_weight - base_weight |
| 68 | + |
| 69 | + max_rank = min(delta_weight.shape) |
| 70 | + assert ( |
| 71 | + reduced_rank <= max_rank |
| 72 | + ), f"The specified rank ({reduced_rank}) must be smaller than or equal to the rank of the weight matrices ({max_rank})" |
| 73 | + |
| 74 | + A, B = _low_rank_decomposition(delta_weight, reduced_rank=reduced_rank) |
| 75 | + |
| 76 | + return A, B |
| 77 | + |
| 78 | + |
| 79 | +def find_all_linear_names(model: PreTrainedModel) -> List[str]: |
| 80 | + cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) |
| 81 | + |
| 82 | + names = [] |
| 83 | + for name, module in model.named_modules(): |
| 84 | + if ( |
| 85 | + isinstance(module, cls) |
| 86 | + or "Linear" in module.__class__.__name__ |
| 87 | + and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",) |
| 88 | + ): |
| 89 | + names.append(name) |
| 90 | + |
| 91 | + return names |
| 92 | + |
| 93 | + |
| 94 | +def get_linear_module_names(model_id: str) -> List[str]: |
| 95 | + model = AutoModelForCausalLM.from_pretrained( |
| 96 | + model_id, state_dict={}, device_map="meta" |
| 97 | + ) # avoid loading weights as we won't need them |
| 98 | + linear_module_names = find_all_linear_names(model) |
| 99 | + |
| 100 | + return linear_module_names |
| 101 | + |
| 102 | + |
| 103 | +def create_peft_config( |
| 104 | + base_model_name_or_path: str, rank: int, alpha: int, target_modules: List[str] |
| 105 | +) -> Dict[str, Any]: |
| 106 | + return { |
| 107 | + "alpha_pattern": {}, |
| 108 | + "auto_mapping": None, |
| 109 | + "base_model_name_or_path": base_model_name_or_path, |
| 110 | + "bias": "none", |
| 111 | + "fan_in_fan_out": False, |
| 112 | + "inference_mode": True, |
| 113 | + "init_lora_weights": True, |
| 114 | + "layers_pattern": None, |
| 115 | + "layers_to_transform": None, |
| 116 | + "loftq_config": {}, |
| 117 | + "lora_alpha": alpha, |
| 118 | + "lora_dropout": 0, |
| 119 | + "megatron_config": None, |
| 120 | + "megatron_core": "megatron.core", |
| 121 | + "modules_to_save": None, |
| 122 | + "peft_type": "LORA", |
| 123 | + "r": rank, |
| 124 | + "rank_pattern": {}, |
| 125 | + "revision": None, |
| 126 | + "target_modules": target_modules, |
| 127 | + "task_type": "CAUSAL_LM", |
| 128 | + "use_rslora": False, |
| 129 | + } |
| 130 | + |
| 131 | + |
| 132 | +def reconstruct_invocation(args): |
| 133 | + """ |
| 134 | + Reconstructs the command-line invocation string based on the given arguments stored in a dictionary. |
| 135 | +
|
| 136 | + Parameters: |
| 137 | + - args: A dictionary containing the command arguments with keys matching the parameter names. |
| 138 | + Expected keys are 'base_model', 'finetuned_model', 'out_path', 'no_lazy_unpickle', 'desired_rank', 'model_name' and 'device'. |
| 139 | +
|
| 140 | + Returns: |
| 141 | + - The reconstructed command-line invocation string. |
| 142 | + """ |
| 143 | + # Provide a default value for out_path if it's not in the dictionary |
| 144 | + out_path = args.get("out_path", "OUTPUT_PATH") |
| 145 | + |
| 146 | + invocation = f"mergekit-extract-lora {args['base_model']} {args['finetuned_model']} {out_path}" |
| 147 | + if args.get("no_lazy_unpickle"): |
| 148 | + invocation += " --no-lazy-unpickle" |
| 149 | + invocation += f" --rank={args['desired_rank']}" |
| 150 | + if args.get("model_name"): |
| 151 | + invocation += f" --model_name={args['model_name']}" |
| 152 | + if args.get("device"): |
| 153 | + invocation += f" --device={args['device']}" |
| 154 | + |
| 155 | + return invocation |
| 156 | + |
| 157 | + |
| 158 | +@click.command("mergekit-extract-lora") |
| 159 | +@click.argument("finetuned_model", type=str) |
| 160 | +@click.argument("base_model", type=str) |
| 161 | +@click.argument("out_path", type=click.Path()) |
| 162 | +@click.option( |
| 163 | + "--no-lazy-unpickle", |
| 164 | + is_flag=True, |
| 165 | + help="Disable lazy unpickler (more stable, higher memory usage)", |
| 166 | +) |
| 167 | +@click.option( |
| 168 | + "--rank", |
| 169 | + "desired_rank", |
| 170 | + type=int, |
| 171 | + default=32, |
| 172 | + help="Rank for the low-rank decomposition", |
| 173 | +) |
| 174 | +@click.option( |
| 175 | + "--model_name", |
| 176 | + type=str, |
| 177 | + default=None, |
| 178 | + help="Name of the resulting model (shown in the model card)", |
| 179 | +) |
| 180 | +@click.option( |
| 181 | + "--device", |
| 182 | + type=str, |
| 183 | + default=None, |
| 184 | + help="PyTorch device to perform SVD computation on", |
| 185 | +) |
| 186 | +def main( |
| 187 | + finetuned_model: str, |
| 188 | + base_model: str, |
| 189 | + out_path: str, |
| 190 | + no_lazy_unpickle: bool, |
| 191 | + desired_rank: int, |
| 192 | + model_name: str, |
| 193 | + device: str, |
| 194 | +) -> None: |
| 195 | + """ |
| 196 | + Decomposes delta weights between a base model and a finetuned model, saving a PEFT model to the specified output path. |
| 197 | +
|
| 198 | + \b |
| 199 | + Arguments: |
| 200 | + FINETUNED_MODEL - the model ID or path to use as the PEFT extraction target model. |
| 201 | + BASE_MODEL - the model ID or path to use as the base model. |
| 202 | + OUT_PATH - the output path where the PEFT model will be saved. |
| 203 | + """ |
| 204 | + |
| 205 | + invocation_args = { |
| 206 | + "base_model": base_model, |
| 207 | + "finetuned_model": finetuned_model, |
| 208 | + "desired_rank": desired_rank, |
| 209 | + "device": device, |
| 210 | + "out_path": out_path, |
| 211 | + "model_name": model_name, |
| 212 | + "no_lazy_unpickle": no_lazy_unpickle, |
| 213 | + } |
| 214 | + |
| 215 | + os.makedirs(out_path, exist_ok=True) |
| 216 | + |
| 217 | + base_model_ref = ModelReference.parse(base_model) |
| 218 | + finetuned_model_ref = ModelReference.parse(finetuned_model) |
| 219 | + |
| 220 | + base_model_config = AutoConfig.from_pretrained(base_model_ref.model.path) |
| 221 | + |
| 222 | + linear_module_names = get_linear_module_names(base_model_ref.model.path) |
| 223 | + finetuned_model_linear_module_names = get_linear_module_names( |
| 224 | + finetuned_model_ref.model.path |
| 225 | + ) |
| 226 | + |
| 227 | + assert set(linear_module_names) == set( |
| 228 | + finetuned_model_linear_module_names |
| 229 | + ), "Model architecture mismatch" |
| 230 | + |
| 231 | + base_loader = LazyTensorLoader( |
| 232 | + base_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle) |
| 233 | + ) |
| 234 | + finetuned_loader = LazyTensorLoader( |
| 235 | + finetuned_model_ref.tensor_index(), lazy_unpickle=(not no_lazy_unpickle) |
| 236 | + ) |
| 237 | + |
| 238 | + lora_weights = {} |
| 239 | + for layer_name in tqdm(linear_module_names): |
| 240 | + base_weight = base_loader.get_tensor(f"{layer_name}.weight") |
| 241 | + finetuned_weight = finetuned_loader.get_tensor(f"{layer_name}.weight") |
| 242 | + |
| 243 | + lora_A, lora_B = decompose_delta_weight( |
| 244 | + finetuned_weight, base_weight, desired_rank, device=device |
| 245 | + ) |
| 246 | + |
| 247 | + lora_weights[f"base_model.model.{layer_name}.lora_A.weight"] = lora_A.to( |
| 248 | + "cpu" |
| 249 | + ).contiguous() |
| 250 | + lora_weights[f"base_model.model.{layer_name}.lora_B.weight"] = lora_B.to( |
| 251 | + "cpu" |
| 252 | + ).contiguous() |
| 253 | + |
| 254 | + lora_config = create_peft_config( |
| 255 | + base_model_name_or_path=base_model_ref.model.path, |
| 256 | + alpha=desired_rank, # Setting the alpha to the reduced rank value as `peft` will scale the LoRA weights by alpha/r when applying the adapter |
| 257 | + rank=desired_rank, |
| 258 | + target_modules=list( |
| 259 | + set([module_name.split(".")[-1] for module_name in linear_module_names]) |
| 260 | + ), |
| 261 | + ) |
| 262 | + |
| 263 | + with open(os.path.join(out_path, "adapter_config.json"), "w") as f: |
| 264 | + json.dump(lora_config, f, indent=2) |
| 265 | + |
| 266 | + save_file(lora_weights, os.path.join(out_path, "adapter_model.safetensors")) |
| 267 | + |
| 268 | + invocation_args.pop("out_path") # don't include out_path for privacy |
| 269 | + invocation = reconstruct_invocation(invocation_args) |
| 270 | + |
| 271 | + card_md = generate_card_lora( |
| 272 | + base_model_ref=base_model_ref, |
| 273 | + finetuned_model_ref=finetuned_model_ref, |
| 274 | + invocation=invocation, |
| 275 | + name=model_name, |
| 276 | + ) |
| 277 | + |
| 278 | + with open(os.path.join(out_path, "README.md"), "w", encoding="utf-8") as fp: |
| 279 | + fp.write(card_md) |
| 280 | + |
| 281 | + logging.info(f"PEFT LoRA adapters saved to {out_path}") |
| 282 | + |
| 283 | + |
| 284 | +if __name__ == "__main__": |
| 285 | + main() |
0 commit comments