- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1.2k
feat:merge-lora iterate through bins without loading #3095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
ac27f1d
              d0c0116
              95f224c
              4ccf5ae
              f094ea2
              fe157bd
              d63de30
              4b2fc64
              1c1c3ab
              d01bb1b
              6cbc74b
              f94415f
              86e86a7
              3384fc5
              d2ce1ab
              6e0617a
              dd24286
              d660e66
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,211 @@ | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Memory-efficient LoRA merging implementation inspired by qlora-pipe. | ||||||||||||||||||||||||
| Processes model shards individually without loading the full model into memory. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||
| import shutil | ||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||
| from typing import Dict, Optional, Union | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| import safetensors.torch | ||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||
| from peft import LoraConfig | ||||||||||||||||||||||||
| from tqdm import tqdm | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| from axolotl.utils.logging import get_logger | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| LOG = get_logger(__name__) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| def find_lora_weights( | ||||||||||||||||||||||||
| lora_state: Dict[str, torch.Tensor], key: str | ||||||||||||||||||||||||
| ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Find corresponding LoRA A and B weights for a given key. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| clean_key = key.rstrip(".weight") | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| lora_a = None | ||||||||||||||||||||||||
| lora_b = None | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| for lora_key, lora_weight in lora_state.items(): | ||||||||||||||||||||||||
| if lora_key.endswith(f"{clean_key}.lora_A.weight"): | ||||||||||||||||||||||||
| lora_a = lora_weight | ||||||||||||||||||||||||
| elif lora_key.endswith(f"{clean_key}.lora_B.weight"): | ||||||||||||||||||||||||
| lora_b = lora_weight | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| if lora_a is not None and lora_b is not None: | ||||||||||||||||||||||||
| return lora_a, lora_b | ||||||||||||||||||||||||
| return None, None | ||||||||||||||||||||||||
|         
                  coderabbitai[bot] marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| def get_model_shards(model_path: Path) -> list[Path]: | ||||||||||||||||||||||||
| """Find all model shards in the given path.""" | ||||||||||||||||||||||||
| shards = list[Path]() | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| patterns = ["model*.safetensors", "pytorch_model*.bin"] | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| for pattern in patterns: | ||||||||||||||||||||||||
| shards.extend(model_path.glob(pattern)) | ||||||||||||||||||||||||
| if shards: | ||||||||||||||||||||||||
| break | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| return sorted(shards) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| def copy_non_model_files( | ||||||||||||||||||||||||
| input_path: Path, output_path: Path, model_shards: list[Path] | ||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Copy all non-model files to the output directory. | ||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||
| input_path: Source directory | ||||||||||||||||||||||||
| output_path: Destination directory | ||||||||||||||||||||||||
| model_shards: List of model shard files to skip | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| LOG.info("Copying non-model files to output directory...") | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| shard_names = {shard.name for shard in model_shards} | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| for filepath in input_path.glob("*"): | ||||||||||||||||||||||||
| if filepath.is_dir(): | ||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||
| if filepath.name in shard_names: | ||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||
| if filepath.suffix == ".gguf": | ||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||
| if filepath.name.startswith("model") and filepath.suffix == ".safetensors": | ||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
|         
                  coderabbitai[bot] marked this conversation as resolved.
              Show resolved
            Hide resolved | ||||||||||||||||||||||||
| LOG.debug(f"Copying {filepath.name} to output") | ||||||||||||||||||||||||
| shutil.copy(filepath, output_path) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| def merge_lora_sharded_efficient( | ||||||||||||||||||||||||
| base_model_path: Union[str, Path], | ||||||||||||||||||||||||
| lora_adapter_path: Union[str, Path], | ||||||||||||||||||||||||
| output_path: Union[str, Path], | ||||||||||||||||||||||||
| device: str = "cuda", | ||||||||||||||||||||||||
| safe_tensors: bool = True, | ||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| Memory-efficient LoRA merging that processes shards individually | ||||||||||||||||||||||||
| without loading the full model into memory. | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
| 
      Comment on lines
    
      127
     to 
      137
    
   There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Safety: prevent in-place overwrite of source directory If       output_path = Path(output_path)
@@
-    os.makedirs(output_path, exist_ok=True)
+    if output_path.resolve() == base_model_path.resolve():
+        raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+    os.makedirs(output_path, exist_ok=True)Also applies to: 101-106 🤖 Prompt for AI Agents | ||||||||||||||||||||||||
| base_model_path = Path(base_model_path) | ||||||||||||||||||||||||
| lora_adapter_path = Path(lora_adapter_path) | ||||||||||||||||||||||||
| output_path = Path(output_path) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| if "/" in str(base_model_path) and not base_model_path.exists(): | ||||||||||||||||||||||||
| from huggingface_hub import snapshot_download | ||||||||||||||||||||||||
|          | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| base_model_path = Path(snapshot_download(str(base_model_path))) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| os.makedirs(output_path, exist_ok=True) | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| config_file = lora_adapter_path / "adapter_config.json" | ||||||||||||||||||||||||
| if not config_file.exists(): | ||||||||||||||||||||||||
| raise FileNotFoundError(f"LoRA config not found: {config_file}") | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| lora_config = LoraConfig.from_json_file(config_file) | ||||||||||||||||||||||||
| scale = lora_config["lora_alpha"] / lora_config["r"] | ||||||||||||||||||||||||
|  | ||||||||||||||||||||||||
| LOG.info(f"LoRA scale factor: {scale}") | ||||||||||||||||||||||||
|         
                  ved1beta marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved          | ||||||||||||||||||||||||
| lora_config = LoraConfig.from_json_file(config_file) | |
| scale = lora_config["lora_alpha"] / lora_config["r"] | |
| LOG.info(f"LoRA scale factor: {scale}") | |
| lora_config = LoraConfig.from_json_file(config_file) | |
| # Ensure 'r' is present and non-zero to avoid division by zero | |
| if not getattr(lora_config, "r", None): | |
| raise ValueError("LoRA config 'r' must be > 0") | |
| scale = float(lora_config.lora_alpha) / float(lora_config.r) | |
| LOG.info(f"LoRA scale factor: {scale}") | 
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 112 to 115, the code
treats LoraConfig as a dict (lora_config["lora_alpha"]) which will raise since
LoraConfig.from_json_file returns an object; change to use attribute access
(e.g., lora_config.lora_alpha and lora_config.r) and add a guard to prevent
division by zero (check if lora_config.r is truthy/greater than zero before
dividing, otherwise set scale to a safe default or raise a clear error), then
log the computed scale.
        
          
              
                  ved1beta marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                  coderabbitai[bot] marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                  ved1beta marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                Outdated
          
        
      There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Support fan_in_fan_out orientation from config
If lora_config.fan_in_fan_out (or lora_fan_in_fan_out) is enabled, a transpose is required to match PEFT’s application. Current code ignores this and may produce incorrect merges for such adapters.
Minimal patch sketch inside both safetensors and .bin branches:
-                        delta = scale * (b_fp32 @ a_fp32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                            delta = delta.TPlease confirm the exact field name used by your PEFT version and add tests for a fan-in-fan-out adapter.
Also applies to: 188-194
🏁 Script executed:
#!/bin/bash
set -e
# Print context around the LoRA merge logic (lines ~150–220)
sed -n '150,220p' src/axolotl/utils/lora_merge_efficient.py
# Locate any definitions or imports of LoraConfig
rg -n "class.*LoraConfig" -g src
rg -n "import.*LoraConfig" -g src
# Check for existing handling of fan_in_fan_out flags
rg -n "fan_in_fan_out" -g src
rg -n "lora_fan_in_fan_out" -g src || trueLength of output: 2842
🏁 Script executed:
#!/bin/bash
set -e
# Print file header and initial context to locate function signature and imports
sed -n '1,200p' src/axolotl/utils/lora_merge_efficient.py
# Search for any mention of fan_in_fan_out in the entire repository
rg -n "fan_in_fan_out" .
# Search for LoraConfig class or imports across the codebase
rg -n "LoraConfig" .
# Search for usage of lora_config within this file
rg -n "lora_config" -n src/axolotl/utils/lora_merge_efficient.pyLength of output: 8450
Support LoRA fan_in_fan_out orientation in merge
The current merge logic in src/axolotl/utils/lora_merge_efficient.py always computes
delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))without accounting for the fan_in_fan_out flag in the PEFT config, which will result in incorrect merges when adapters were trained with fan_in_fan_out=True.
Please apply the following change in both the .safetensors branch (around lines 170–175) and the .bin branch (around lines 188–194):
-   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   if lora_config.fan_in_fan_out:
+       delta = delta.T• Locations to update:
- safetensors loop (after line 170)
- torch.load loop (after line 188)
• Add a unit test with a LoRA adapter configured as fan_in_fan_out=True to verify the transpose is applied correctly.
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 170–175 (safetensors
branch) and around lines 188–194 (torch.load/.bin branch), the merge always
computes delta as scale * (lora_b @ lora_a) and ignores the PEFT config flag
fan_in_fan_out; update both locations to check the adapter config and, when
fan_in_fan_out is True, transpose lora_a and lora_b appropriately (e.g., swap or
transpose operands so multiplication reflects the trained orientation) before
computing delta, then cast back to original dtype as now; also add a unit test
that loads/creates a LoRA adapter with fan_in_fan_out=True, runs the merge, and
asserts the merged tensor matches the expected result when the transpose branch
is applied.
        
          
              
                  coderabbitai[bot] marked this conversation as resolved.
              
          
            Show resolved
            Hide resolved
        
              
          
              
                  coderabbitai[bot] marked this conversation as resolved.
              
              
                Outdated
          
            Show resolved
            Hide resolved
        
              
          
              
                  coderabbitai[bot] marked this conversation as resolved.
              
          
            Show resolved
            Hide resolved
        
      | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| """Pydantic models for PEFT-related configuration""" | ||
|  | ||
| from typing import Any | ||
| from typing import Any, Literal | ||
|  | ||
| from pydantic import BaseModel, Field, field_validator, model_validator | ||
|  | ||
|  | @@ -130,6 +130,12 @@ class LoraConfig(BaseModel): | |
| ) | ||
|  | ||
| merge_lora: bool | None = None | ||
| merge_method: Literal["legacy", "memory_efficient"] | None = Field( | ||
|          | ||
| default="memory_efficient", | ||
| json_schema_extra={ | ||
| "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory." | ||
| }, | ||
| ) | ||
|  | ||
| @model_validator(mode="before") | ||
| @classmethod | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.