feat:merge-lora iterate through bins without loading #3095
@@ -4,23 +4,45 @@
from typing import Union

import fire
import torch

from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient

LOG = get_logger(__name__)


def do_merge_lora(*, cfg: DictDefault) -> None:
    """
    Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
    along with the LoRA adapters to combine them into a single base model.
    Merges LoRA adapters with base model using either memory-efficient or legacy approach.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
    """
    merge_method = (
        str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
    )
    if merge_method in {"legacy", "standard"}:
        LOG.debug("Using legacy LoRA merging method...")
        _do_merge_lora_legacy(cfg=cfg)
    else:
        LOG.debug("Using memory-efficient LoRA merging method...")
        try:
            _do_merge_lora_efficient(cfg=cfg)
        except Exception:  # pylint: disable=broad-exception-caught
            LOG.exception("Memory-efficient merge failed; falling back to legacy.")
            _do_merge_lora_legacy(cfg=cfg)
        try:
            _do_merge_lora_efficient(cfg=cfg)
        except Exception:  # pylint: disable=broad-exception-caught
            LOG.exception("Memory-efficient merge failed; falling back to legacy.")
            _do_merge_lora_legacy(cfg=cfg)

Suggested change:

        _do_merge_lora_efficient(cfg=cfg)
If there are unsupported combinations (you mentioned DoRA, RSLoRA), we should validate this in the pydantic model and raise an error there.
ved1beta marked this conversation as resolved.
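A possible shape for that validation in the PEFT pydantic model, as a sketch only: the field names `peft_use_dora` and `peft_use_rslora` are assumptions about the axolotl config surface, and the class below is a trimmed-down stand-in for the real `LoraConfig`.

```python
from typing import Literal, Optional

from pydantic import BaseModel, model_validator


class LoraConfig(BaseModel):
    """Trimmed-down stand-in for the real axolotl LoraConfig (illustration only)."""

    # Assumed field names; the actual model may spell these differently.
    merge_method: Optional[Literal["legacy", "memory_efficient"]] = "memory_efficient"
    peft_use_dora: Optional[bool] = None
    peft_use_rslora: Optional[bool] = None

    @model_validator(mode="after")
    def check_merge_method_supported(self):
        # Memory-efficient merging only handles standard LoRA adapters.
        if self.merge_method == "memory_efficient" and (
            self.peft_use_dora or self.peft_use_rslora
        ):
            raise ValueError(
                "merge_method='memory_efficient' does not support DoRA/RSLoRA adapters; "
                "set merge_method='legacy' instead."
            )
        return self
```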
@@ -0,0 +1,282 @@
"""
Memory-efficient LoRA merging implementation inspired by qlora-pipe.
Processes model shards individually without loading the full model into memory.
"""

import gc
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import safetensors
import safetensors.torch
import torch
from peft import LoraConfig
from tqdm import tqdm

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def find_lora_weights(
    lora_state: Dict[str, torch.Tensor], key: str
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Find corresponding LoRA A and B weights for a given key.
    """
    clean_key = key[:-7] if key.endswith(".weight") else key

    a_key = f"base_model.model.{clean_key}.lora_A.weight"
    b_key = f"base_model.model.{clean_key}.lora_B.weight"

    lora_a = lora_state.get(a_key)
    lora_b = lora_state.get(b_key)

    if lora_a is not None and lora_b is not None:
        return lora_a, lora_b
    return None, None


def get_model_shards(model_path: Path) -> list[Path]:
    """Find all model shards in the given path."""
    shards: list[Path] = []

    patterns = ["model*.safetensors", "pytorch_model*.bin"]

    for pattern in patterns:
        shards.extend(model_path.glob(pattern))
        if shards:
            break

    return sorted(shards)


def copy_non_model_files(
    input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
    """
    Copy all non-model files to the output directory.

    Args:
        input_path: Source directory
        output_path: Destination directory
        model_shards: List of model shard files to skip
    """
    LOG.info("Copying non-model files to output directory...")

    shard_names = {shard.name for shard in model_shards}

    for filepath in input_path.glob("*"):
        if filepath.is_dir():
            continue
        if filepath.name in shard_names:
            continue
        if (
            filepath.name.startswith("model") and filepath.suffix == ".safetensors"
        ) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
            continue
        if filepath.suffix == ".gguf":
            continue

        LOG.debug(f"Copying {filepath.name} to output")
        shutil.copy2(filepath, output_path)


def merge_lora_sharded_efficient(
    base_model_path: Union[str, Path],
    lora_adapter_path: Union[str, Path],
    output_path: Union[str, Path],
    device: str = "cpu",
    safe_tensors: bool = True,
) -> None:
    """
    Memory-efficient LoRA merging that processes shards individually
    without loading the full model into memory.
    """
    base_model_path = Path(base_model_path)
    lora_adapter_path = Path(lora_adapter_path)
    output_path = Path(output_path)

    if "/" in str(base_model_path) and not base_model_path.exists():
        from huggingface_hub import snapshot_download

        base_model_path = Path(snapshot_download(str(base_model_path)))

    os.makedirs(output_path, exist_ok=True)

    config_file = lora_adapter_path / "adapter_config.json"
    if not config_file.exists():
        raise FileNotFoundError(f"LoRA config not found: {config_file}")

    lora_config_dict = LoraConfig.from_json_file(str(config_file))
    if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0:
        raise ValueError("LoRA config 'r' must be > 0")

    unsupported_methods = []

    # Check for DoRA (Weight-Decomposed LoRA)
    if lora_config_dict.get("use_dora", False):
        unsupported_methods.append("DoRA (Weight-Decomposed LoRA)")

    # Check for AdaLoRA (Adaptive LoRA)
    if lora_config_dict.get("use_adalora", False):
        unsupported_methods.append("AdaLoRA (Adaptive LoRA)")

    # Check for VeRA (Vector-based Random Matrix Adaptation)
    if lora_config_dict.get("use_vera", False):
        unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)")

    # Check for other advanced LoRA variants by task_type
    task_type = lora_config_dict.get("task_type", "")
    if task_type and task_type not in [
        "CAUSAL_LM",
        "SEQ_2_SEQ_LM",
        "TOKEN_CLS",
        "SEQ_CLS",
        "QUESTION_ANS",
    ]:
        unsupported_methods.append(f"Task type: {task_type}")

    # Check for rank adaptation patterns (AdaLoRA indicators)
    if any(
        key in lora_config_dict
        for key in ["rank_pattern", "alpha_pattern", "target_rank"]
    ):
        unsupported_methods.append("AdaLoRA (rank adaptation detected)")

    # Check for advanced initialization methods
    init_lora_weights = lora_config_dict.get("init_lora_weights", "")
    if init_lora_weights and init_lora_weights not in [
        "gaussian",
        "loftq",
        True,
        False,
    ]:
        unsupported_methods.append(f"Advanced initialization: {init_lora_weights}")

    if unsupported_methods:
        methods_str = ", ".join(unsupported_methods)
        raise NotImplementedError(
            f"Memory-efficient LoRA merge only supports standard LoRA. "
            f"Detected unsupported methods: {methods_str}. "
            f"Please use the legacy merge method for advanced LoRA variants."
        )

    scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"])

    LOG.debug(f"LoRA scale factor: {scale}")

    lora_file = lora_adapter_path / "adapter_model.safetensors"
    if not lora_file.exists():
        lora_file = lora_adapter_path / "adapter_model.bin"
        if not lora_file.exists():
            raise FileNotFoundError(
                f"LoRA adapter weights not found in {lora_adapter_path}"
            )

    LOG.debug(f"Loading LoRA weights from {lora_file}")

    if lora_file.suffix == ".safetensors":
        lora_state = safetensors.torch.load_file(lora_file)
    else:
        try:
            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614
        except TypeError:
            lora_state = torch.load(lora_file, map_location="cpu")  # nosec B614

    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

    model_shards = get_model_shards(base_model_path)
    if not model_shards:
        raise FileNotFoundError(f"No model shards found in {base_model_path}")

    LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
    copy_non_model_files(base_model_path, output_path, model_shards)

    merged_count = 0
    total_tensors = 0

    for shard_path in tqdm(model_shards, desc="Merging shards"):
        merged_tensors = {}
        metadata = {}

        if shard_path.suffix == ".safetensors":
            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
                if hasattr(f, "metadata") and f.metadata():
                    metadata = f.metadata()

                for key in f.keys():
                    total_tensors += 1
                    tensor = f.get_tensor(key)
                    lora_a, lora_b = find_lora_weights(lora_state, key)

                    if lora_a is not None and lora_b is not None:
                        merged_count += 1
                        LOG.debug(
                            f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}"
                        )

                        original_dtype = tensor.dtype
                        base_fp32 = tensor.to(device).to(torch.float32)
                        a_fp32 = lora_a.to(device).to(torch.float32)
                        b_fp32 = lora_b.to(device).to(torch.float32)
                        delta = scale * (b_fp32 @ a_fp32)
                        if bool(
                            lora_config_dict.get("fan_in_fan_out", False)
                            or lora_config_dict.get("lora_fan_in_fan_out", False)
                        ):
                            delta = delta.T
                        merged_tensors[key] = (
                            (base_fp32 + delta).to(original_dtype).detach().cpu()
                        )
                        del base_fp32, a_fp32, b_fp32, delta
                    else:
                        merged_tensors[key] = tensor.detach().cpu()
        else:
            state_dict = torch.load(  # nosec B614: loading trusted model weights
                shard_path, map_location="cpu", weights_only=True
            )
            for key, tensor in state_dict.items():
                total_tensors += 1
                lora_a, lora_b = find_lora_weights(lora_state, key)

                if lora_a is not None and lora_b is not None:
                    merged_count += 1
                    original_dtype = tensor.dtype
                    base_fp32 = tensor.to(device).to(torch.float32)
                    a_fp32 = lora_a.to(device).to(torch.float32)
                    b_fp32 = lora_b.to(device).to(torch.float32)
                    delta = scale * (b_fp32 @ a_fp32)
                    if bool(
                        lora_config_dict.get("fan_in_fan_out", False)
                        or lora_config_dict.get("lora_fan_in_fan_out", False)
                    ):
                        delta = delta.T
                    merged_tensors[key] = (
                        (base_fp32 + delta).to(original_dtype).detach().cpu()
                    )
                    del base_fp32, a_fp32, b_fp32, delta
                else:
                    merged_tensors[key] = tensor.detach().cpu()

        output_shard_path = output_path / shard_path.name
        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
        if shard_path.suffix == ".safetensors":
            safetensors.torch.save_file(
                merged_tensors, output_shard_path, metadata=metadata
            )
        else:
            if safe_tensors:
                LOG.warning(
                    "safe_tensors=True requested but input shards are .bin; preserving .bin format "
                    "to avoid index mismatches."
                )
            torch.save(merged_tensors, output_shard_path)

        del merged_tensors
        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")
@@ -1,6 +1,6 @@
"""Pydantic models for PEFT-related configuration"""

from typing import Any
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator, model_validator

@@ -140,6 +140,12 @@ class LoraConfig(BaseModel):
    )

    merge_lora: bool | None = None
    merge_method: Literal["legacy", "memory_efficient"] | None = Field(
        default="memory_efficient",
        json_schema_extra={
            "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory."
        },
    )

    @model_validator(mode="before")
    @classmethod
`merge_method` can only take the values `Literal["legacy", "memory_efficient"]`, so you don't need this string handling.
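With the field typed as that `Literal` (see the pydantic diff above), the dispatch in `do_merge_lora` could reduce to a direct comparison. A sketch only, reusing the helpers already defined in the CLI module:

```python
def do_merge_lora(*, cfg: DictDefault) -> None:
    """Dispatch on the already-validated config value; no string normalization needed."""
    if cfg.merge_method == "legacy":
        LOG.debug("Using legacy LoRA merging method...")
        _do_merge_lora_legacy(cfg=cfg)
    else:  # "memory_efficient" (the default) or None
        LOG.debug("Using memory-efficient LoRA merging method...")
        _do_merge_lora_efficient(cfg=cfg)
```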