-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat:merge-lora iterate through bins without loading #3095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
ac27f1d
d0c0116
95f224c
4ccf5ae
f094ea2
fe157bd
d63de30
4b2fc64
1c1c3ab
d01bb1b
6cbc74b
f94415f
86e86a7
3384fc5
d2ce1ab
6e0617a
dd24286
d660e66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| """ | ||
| Memory-efficient LoRA merging implementation inspired by qlora-pipe. | ||
| Processes model shards individually without loading the full model into memory. | ||
| """ | ||
|
|
||
| import os | ||
| import shutil | ||
| from pathlib import Path | ||
| from typing import Dict, Optional, Union | ||
|
|
||
| import safetensors.torch | ||
| import torch | ||
| from peft import LoraConfig | ||
| from tqdm import tqdm | ||
|
|
||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| LOG = get_logger(__name__) | ||
|
|
||
|
|
||
| def find_lora_weights( | ||
| lora_state: Dict[str, torch.Tensor], key: str | ||
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | ||
| """ | ||
| Find corresponding LoRA A and B weights for a given key. | ||
| """ | ||
| clean_key = key.rstrip(".weight") | ||
|
|
||
| lora_a = None | ||
| lora_b = None | ||
|
|
||
| for lora_key, lora_weight in lora_state.items(): | ||
| if lora_key.endswith(f"{clean_key}.lora_A.weight"): | ||
| lora_a = lora_weight | ||
| elif lora_key.endswith(f"{clean_key}.lora_B.weight"): | ||
| lora_b = lora_weight | ||
|
|
||
| if lora_a is not None and lora_b is not None: | ||
| return lora_a, lora_b | ||
| return None, None | ||
coderabbitai[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def get_model_shards(model_path: Path) -> list[Path]: | ||
| """Find all model shards in the given path.""" | ||
| shards = list[Path]() | ||
|
|
||
| patterns = ["model*.safetensors", "pytorch_model*.bin"] | ||
|
|
||
| for pattern in patterns: | ||
| shards.extend(model_path.glob(pattern)) | ||
| if shards: | ||
| break | ||
|
|
||
| return sorted(shards) | ||
|
|
||
|
|
||
| def copy_non_model_files( | ||
| input_path: Path, output_path: Path, model_shards: list[Path] | ||
| ) -> None: | ||
| """ | ||
| Copy all non-model files to the output directory. | ||
|
|
||
| Args: | ||
| input_path: Source directory | ||
| output_path: Destination directory | ||
| model_shards: List of model shard files to skip | ||
| """ | ||
| LOG.info("Copying non-model files to output directory...") | ||
|
|
||
| shard_names = {shard.name for shard in model_shards} | ||
|
|
||
| for filepath in input_path.glob("*"): | ||
| if filepath.is_dir(): | ||
| continue | ||
| if filepath.name in shard_names: | ||
| continue | ||
| if filepath.suffix == ".gguf": | ||
| continue | ||
| if filepath.name.startswith("model") and filepath.suffix == ".safetensors": | ||
| continue | ||
|
|
||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| LOG.debug(f"Copying {filepath.name} to output") | ||
| shutil.copy(filepath, output_path) | ||
|
|
||
|
|
||
| def merge_lora_sharded_efficient( | ||
| base_model_path: Union[str, Path], | ||
| lora_adapter_path: Union[str, Path], | ||
| output_path: Union[str, Path], | ||
| device: str = "cuda", | ||
| safe_tensors: bool = True, | ||
| ) -> None: | ||
| """ | ||
| Memory-efficient LoRA merging that processes shards individually | ||
| without loading the full model into memory. | ||
| """ | ||
|
Comment on lines
127
to
137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Safety: prevent in-place overwrite of source directory If output_path = Path(output_path)
@@
- os.makedirs(output_path, exist_ok=True)
+ if output_path.resolve() == base_model_path.resolve():
+ raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+ os.makedirs(output_path, exist_ok=True)Also applies to: 101-106 🤖 Prompt for AI Agents |
||
| base_model_path = Path(base_model_path) | ||
| lora_adapter_path = Path(lora_adapter_path) | ||
| output_path = Path(output_path) | ||
|
|
||
| if "/" in str(base_model_path) and not base_model_path.exists(): | ||
| from huggingface_hub import snapshot_download | ||
|
||
|
|
||
| base_model_path = Path(snapshot_download(str(base_model_path))) | ||
|
|
||
| os.makedirs(output_path, exist_ok=True) | ||
|
|
||
| config_file = lora_adapter_path / "adapter_config.json" | ||
| if not config_file.exists(): | ||
| raise FileNotFoundError(f"LoRA config not found: {config_file}") | ||
|
|
||
| lora_config = LoraConfig.from_json_file(config_file) | ||
| scale = lora_config["lora_alpha"] / lora_config["r"] | ||
|
|
||
| LOG.debug(f"LoRA scale factor: {scale}") | ||
|
|
||
| lora_file = lora_adapter_path / "adapter_model.safetensors" | ||
| if not lora_file.exists(): | ||
| lora_file = lora_adapter_path / "adapter_model.bin" | ||
| if not lora_file.exists(): | ||
| raise FileNotFoundError( | ||
| f"LoRA adapter weights not found in {lora_adapter_path}" | ||
| ) | ||
|
|
||
| LOG.debug(f"Loading LoRA weights from {lora_file}") | ||
|
|
||
| if lora_file.suffix == ".safetensors": | ||
| lora_state = safetensors.torch.load_file(lora_file) | ||
| else: | ||
| lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) | ||
|
|
||
| if device != "cpu": | ||
| LOG.debug(f"Moving LoRA weights to {device}") | ||
| for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): | ||
| lora_state[key] = value.to(device) | ||
|
|
||
| model_shards = get_model_shards(base_model_path) | ||
| if not model_shards: | ||
| raise FileNotFoundError(f"No model shards found in {base_model_path}") | ||
|
|
||
| LOG.debug(f"Found {len(model_shards)} model shards") | ||
| copy_non_model_files(base_model_path, output_path, model_shards) | ||
|
|
||
| merged_count = 0 | ||
| total_tensors = 0 | ||
|
|
||
| for shard_path in tqdm(model_shards, desc="Merging shards"): | ||
| merged_tensors = {} | ||
| metadata = {} | ||
|
|
||
| if shard_path.suffix == ".safetensors": | ||
| with safetensors.safe_open(shard_path, framework="pt", device=device) as f: | ||
| if hasattr(f, "metadata") and f.metadata(): | ||
| metadata = f.metadata() | ||
|
|
||
| for key in f.keys(): | ||
| total_tensors += 1 | ||
| tensor = f.get_tensor(key) | ||
| lora_a, lora_b = find_lora_weights(lora_state, key) | ||
|
|
||
| if lora_a is not None and lora_b is not None: | ||
| merged_count += 1 | ||
| LOG.debug( | ||
| f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}" | ||
| ) | ||
|
|
||
| original_dtype = tensor.dtype | ||
| tensor_fp32 = tensor.to(torch.float32) | ||
|
|
||
| delta = scale * ( | ||
| lora_b.to(torch.float32) @ lora_a.to(torch.float32) | ||
| ) | ||
|
|
||
| merged_tensor = (tensor_fp32 + delta).to(original_dtype) | ||
| merged_tensors[key] = merged_tensor | ||
|
||
| else: | ||
| merged_tensors[key] = tensor | ||
| else: | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| state_dict = torch.load(shard_path, map_location=device) # nosec B614: loading trusted model weights | ||
| for key, tensor in state_dict.items(): | ||
| total_tensors += 1 | ||
| lora_a, lora_b = find_lora_weights(lora_state, key) | ||
|
|
||
| if lora_a is not None and lora_b is not None: | ||
| merged_count += 1 | ||
| original_dtype = tensor.dtype | ||
| tensor_fp32 = tensor.to(torch.float32) | ||
| delta = scale * ( | ||
| lora_b.to(torch.float32) @ lora_a.to(torch.float32) | ||
| ) | ||
| merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype) | ||
| else: | ||
| merged_tensors[key] = tensor | ||
|
|
||
| output_shard_path = output_path / shard_path.name | ||
| if safe_tensors and shard_path.suffix == ".safetensors": | ||
| safetensors.torch.save_file( | ||
| merged_tensors, output_shard_path, metadata=metadata | ||
| ) | ||
| else: | ||
| if safe_tensors: | ||
| output_shard_path = output_shard_path.with_suffix(".safetensors") | ||
| torch.save(merged_tensors, output_shard_path) | ||
|
|
||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| del merged_tensors | ||
| if device != "cpu": | ||
| torch.cuda.empty_cache() | ||
|
|
||
| LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """Pydantic models for PEFT-related configuration""" | ||
|
|
||
| from typing import Any | ||
| from typing import Any, Literal | ||
|
|
||
| from pydantic import BaseModel, Field, field_validator, model_validator | ||
|
|
||
|
|
@@ -130,6 +130,12 @@ class LoraConfig(BaseModel): | |
| ) | ||
|
|
||
| merge_lora: bool | None = None | ||
| merge_method: Literal["legacy", "memory_efficient"] | None = Field( | ||
|
||
| default="memory_efficient", | ||
| json_schema_extra={ | ||
| "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory." | ||
| }, | ||
| ) | ||
|
|
||
| @model_validator(mode="before") | ||
| @classmethod | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.