40 changes: 37 additions & 3 deletions src/axolotl/cli/merge_lora.py
@@ -9,18 +9,31 @@
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient

LOG = get_logger(__name__)


def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Merges LoRA adapters into the base model using either the standard or the memory-efficient approach.

Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
merge_method = getattr(cfg, "merge_method", "standard")
if merge_method == "memory_efficient":
_do_merge_lora_efficient(cfg=cfg)
else:
_do_merge_lora_standard(cfg=cfg)
Collaborator:
Let's be opinionated and prefer the memory_efficient merge. The old "standard" version can be a manual fallback if users have issues. You could also wrap `_do_merge_lora_efficient` in a try/except and, if it fails, provide a hint to the user to use the standard option.
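A rough sketch of that suggestion (illustration only, not code from this PR; it assumes the default flips to the memory-efficient path and reuses the module's `_do_merge_lora_standard`, `_do_merge_lora_efficient`, and `LOG`):

```python
def do_merge_lora(*, cfg: DictDefault) -> None:
    """Prefer the memory-efficient merge; hint at the legacy path on failure."""
    if getattr(cfg, "merge_method", "memory_efficient") == "standard":
        _do_merge_lora_standard(cfg=cfg)
        return
    try:
        _do_merge_lora_efficient(cfg=cfg)
    except Exception as err:  # noqa: BLE001 - re-raise after pointing at the fallback
        LOG.error(
            "Memory-efficient merge failed (%s); retry with `merge_method: standard` "
            "to use the legacy merge_and_unload path.",
            err,
        )
        raise
```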

Collaborator:
Maybe rename "standard" to "legacy" or something else, since it's no longer axolotl's standard.

Collaborator:
One thing to consider is that this only works for standard LoRA, and not other advanced methods like DoRA
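One way the efficient path could fail fast on such adapters (a sketch, not part of the PR; it assumes PEFT's `LoraConfig` exposes a `use_dora` flag for DoRA adapters):

```python
def _assert_plain_lora(lora_config) -> None:
    """Refuse adapter types the shard-wise merge does not handle."""
    if getattr(lora_config, "use_dora", False):
        raise NotImplementedError(
            "memory_efficient merging only supports plain LoRA adapters; "
            "use `merge_method: standard` for DoRA."
        )
```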



def _do_merge_lora_standard(*, cfg: DictDefault) -> None:
"""
Standard LoRA merging using `merge_and_unload`.
Loads the full model into memory before merging.
"""
LOG.info("Using standard LoRA merging method...")
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True

@@ -49,6 +62,27 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
"""
Memory-efficient LoRA merging using shard-by-shard processing.
Does not load the full model into memory.
"""
LOG.info("Using memory-efficient LoRA merging method...")

output_path = Path(cfg.output_dir) / "merged"
safe_tensors = getattr(cfg, "save_safetensors", True)

# Perform memory-efficient merge
merge_lora_sharded_efficient(
base_model_path=cfg.base_model,
lora_adapter_path=cfg.lora_model_dir,
output_path=output_path,
safe_tensors=safe_tensors,
)

LOG.info("Memory-efficient LoRA merge completed successfully!")


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
@@ -80,7 +114,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
)

do_merge_lora(cfg=parsed_cfg)
214 changes: 214 additions & 0 deletions src/axolotl/utils/lora_merge_efficient.py
@@ -0,0 +1,214 @@
"""
Memory-efficient LoRA merging implementation inspired by qlora-pipe.
Processes model shards individually without loading the full model into memory.
"""

import os
import re
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import safetensors.torch
import torch
from peft import LoraConfig
from tqdm import tqdm

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def find_lora_weights(
lora_state: Dict[str, torch.Tensor], key: str
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Find corresponding LoRA A and B weights for a given key.
"""
clean_key = key.removesuffix(".weight")  # drop the suffix; str.strip would remove characters
clean_key = re.sub(r"^(base_model\.model\.|language_model\.)", "", clean_key)

lora_a = None
lora_b = None

for lora_key, lora_weight in lora_state.items():
if clean_key in lora_key:
if "lora_A" in lora_key:
lora_a = lora_weight
elif "lora_B" in lora_key:
lora_b = lora_weight

if lora_a is not None and lora_b is not None:
return lora_a, lora_b
return None, None


def get_model_shards(model_path: Path) -> list[Path]:
"""Find all model shards in the given path."""
shards = list[Path]()

patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"]

for pattern in patterns:
shards.extend(model_path.glob(pattern))
if shards:
break

return sorted(shards)


def copy_non_model_files(
input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
"""
Copy all non-model files to the output directory.

Args:
input_path: Source directory
output_path: Destination directory
model_shards: List of model shard files to skip
"""
LOG.info("Copying non-model files to output directory...")

shard_names = {shard.name for shard in model_shards}

for filepath in input_path.glob("*"):
if filepath.is_dir():
continue
if filepath.name in shard_names:
continue
if filepath.suffix == ".gguf":
continue
if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
continue

LOG.debug(f"Copying {filepath.name} to output")
shutil.copy(filepath, output_path)


def merge_lora_sharded_efficient(
base_model_path: Union[str, Path],
lora_adapter_path: Union[str, Path],
output_path: Union[str, Path],
device: str = "cuda",
safe_tensors: bool = True,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
without loading the full model into memory.
"""
Comment on lines 127 to 137
Contributor:

🛠️ Refactor suggestion

Safety: prevent in-place overwrite of source directory

If output_path equals base_model_path, you risk clobbering source files. Add a guard and fail fast.

     output_path = Path(output_path)
@@
-    os.makedirs(output_path, exist_ok=True)
+    if output_path.resolve() == base_model_path.resolve():
+        raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+    os.makedirs(output_path, exist_ok=True)

Also applies to: 101-106

🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 86-96 (and similarly
at 101-106), add a fail-fast guard that prevents output_path from being the same
as base_model_path (and also disallow matching lora_adapter_path) to avoid
in-place overwrites; implement by converting inputs to pathlib.Path and
comparing resolved absolute paths (Path(...).resolve()) and if any matches raise
a clear ValueError (or SystemExit) with a message like "output_path must be
different from base_model_path/lora_adapter_path" before performing any file
operations.
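
Rendered as a plain snippet, the guard described above might look like this (a sketch only; the helper name is hypothetical and placement right after the `Path(...)` conversions in `merge_lora_sharded_efficient` is assumed):

```python
from pathlib import Path


def _guard_output_path(base_model_path, lora_adapter_path, output_path) -> None:
    """Fail fast if the merge would overwrite its own inputs."""
    if Path(output_path).resolve() in (
        Path(base_model_path).resolve(),
        Path(lora_adapter_path).resolve(),
    ):
        raise ValueError(
            "output_path must be different from base_model_path/lora_adapter_path"
        )
```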

base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
output_path = Path(output_path)

if "/" in str(base_model_path) and not base_model_path.exists():
from huggingface_hub import snapshot_download
Member:
Can be a top-level import.


base_model_path = Path(snapshot_download(str(base_model_path)))

os.makedirs(output_path, exist_ok=True)

config_file = lora_adapter_path / "adapter_config.json"
if not config_file.exists():
raise FileNotFoundError(f"LoRA config not found: {config_file}")

lora_config = LoraConfig.from_pretrained(lora_adapter_path)
scale = lora_config.lora_alpha / lora_config.r

LOG.info(f"LoRA scale factor: {scale}")

lora_file = lora_adapter_path / "adapter_model.safetensors"
if not lora_file.exists():
lora_file = lora_adapter_path / "adapter_model.bin"
if not lora_file.exists():
raise FileNotFoundError(
f"LoRA adapter weights not found in {lora_adapter_path}"
)

LOG.debug(f"Loading LoRA weights from {lora_file}")

if lora_file.suffix == ".safetensors":
lora_state = safetensors.torch.load_file(lora_file)
else:
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)

if device != "cpu":
LOG.info(f"Moving LoRA weights to {device}")
for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
lora_state[key] = value.to(device)

model_shards = get_model_shards(base_model_path)
if not model_shards:
raise FileNotFoundError(f"No model shards found in {base_model_path}")

LOG.info(f"Found {len(model_shards)} model shards")
copy_non_model_files(base_model_path, output_path, model_shards)

merged_count = 0
total_tensors = 0

for shard_path in tqdm(model_shards, desc="Merging shards"):
merged_tensors = {}
metadata = {}

if shard_path.suffix == ".safetensors":
with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
if hasattr(f, "metadata") and f.metadata():
metadata = f.metadata()

for key in f.keys():
total_tensors += 1
tensor = f.get_tensor(key)
lora_a, lora_b = find_lora_weights(lora_state, key)

if lora_a is not None and lora_b is not None:
merged_count += 1
LOG.debug(
f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}"
)

original_dtype = tensor.dtype
tensor_fp32 = tensor.to(torch.float32)

delta = scale * (
lora_b.to(torch.float32) @ lora_a.to(torch.float32)
)

merged_tensor = (tensor_fp32 + delta).to(original_dtype)
merged_tensors[key] = merged_tensor
Contributor:

💡 Verification agent

🧩 Analysis chain

Support fan_in_fan_out orientation from config

If lora_config.fan_in_fan_out (or lora_fan_in_fan_out) is enabled, a transpose is required to match PEFT’s application. Current code ignores this and may produce incorrect merges for such adapters.

Minimal patch sketch inside both safetensors and .bin branches:

-                        delta = scale * (b_fp32 @ a_fp32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                            delta = delta.T

Please confirm the exact field name used by your PEFT version and add tests for a fan-in-fan-out adapter.

Also applies to: 188-194


🏁 Script executed:

#!/bin/bash
set -e

# Print context around the LoRA merge logic (lines ~150–220)
sed -n '150,220p' src/axolotl/utils/lora_merge_efficient.py

# Locate any definitions or imports of LoraConfig
rg -n "class.*LoraConfig" -g src
rg -n "import.*LoraConfig" -g src

# Check for existing handling of fan_in_fan_out flags
rg -n "fan_in_fan_out" -g src
rg -n "lora_fan_in_fan_out" -g src || true

Length of output: 2842


🏁 Script executed:

#!/bin/bash
set -e

# Print file header and initial context to locate function signature and imports
sed -n '1,200p' src/axolotl/utils/lora_merge_efficient.py

# Search for any mention of fan_in_fan_out in the entire repository
rg -n "fan_in_fan_out" .

# Search for LoraConfig class or imports across the codebase
rg -n "LoraConfig" .

# Search for usage of lora_config within this file
rg -n "lora_config" -n src/axolotl/utils/lora_merge_efficient.py

Length of output: 8450


Support LoRA fan_in_fan_out orientation in merge

The current merge logic in src/axolotl/utils/lora_merge_efficient.py always computes

delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))

without accounting for the fan_in_fan_out flag in the PEFT config, which will result in incorrect merges when adapters were trained with fan_in_fan_out=True.

Please apply the following change in both the .safetensors branch (around lines 170–175) and the .bin branch (around lines 188–194):

-   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   if lora_config.fan_in_fan_out:
+       delta = delta.T

• Locations to update:

  • safetensors loop (after line 170)
  • torch.load loop (after line 188)

• Add a unit test with a LoRA adapter configured as fan_in_fan_out=True to verify the transpose is applied correctly.

🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 170–175 (safetensors
branch) and around lines 188–194 (torch.load/.bin branch), the merge always
computes delta as scale * (lora_b @ lora_a) and ignores the PEFT config flag
fan_in_fan_out; update both locations to check the adapter config and, when
fan_in_fan_out is True, transpose lora_a and lora_b appropriately (e.g., swap or
transpose operands so multiplication reflects the trained orientation) before
computing delta, then cast back to original dtype as now; also add a unit test
that loads/creates a LoRA adapter with fan_in_fan_out=True, runs the merge, and
asserts the merged tensor matches the expected result when the transpose branch
is applied.
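
To see why the transpose matters, here is a self-contained numeric check (not from the PR; shapes and the alpha/r values are arbitrary). A fan_in_fan_out layer stores its weight as (in_features, out_features) and computes `x @ W`, so the merged delta has to be `(B @ A).T`:

```python
import torch

torch.manual_seed(0)
in_f, out_f, r = 8, 6, 2
scale = 16 / r                       # lora_alpha / r, arbitrary values

W = torch.randn(in_f, out_f)         # fan_in_fan_out weight: (in, out), forward is x @ W
A = torch.randn(r, in_f)             # lora_A.weight
B = torch.randn(out_f, r)            # lora_B.weight
x = torch.randn(4, in_f)

y_lora = x @ W + scale * (x @ A.T @ B.T)   # what PEFT computes at runtime
merged = W + scale * (B @ A).T             # transposed delta for fan_in_fan_out layers

assert torch.allclose(x @ merged, y_lora, atol=1e-5)
```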

else:
merged_tensors[key] = tensor
else:
state_dict = torch.load(
shard_path, map_location=device
) # nosec B614: loading trusted model weights
for key, tensor in state_dict.items():
total_tensors += 1
lora_a, lora_b = find_lora_weights(lora_state, key)

if lora_a is not None and lora_b is not None:
merged_count += 1
original_dtype = tensor.dtype
tensor_fp32 = tensor.to(torch.float32)
delta = scale * (
lora_b.to(torch.float32) @ lora_a.to(torch.float32)
)
merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
else:
merged_tensors[key] = tensor

output_shard_path = output_path / shard_path.name
if safe_tensors and shard_path.suffix == ".safetensors":
safetensors.torch.save_file(
merged_tensors, output_shard_path, metadata=metadata
)
elif safe_tensors:
# Shard was a torch .bin file but safetensors output was requested:
# convert it rather than writing a pickle with a .safetensors extension.
output_shard_path = output_shard_path.with_suffix(".safetensors")
safetensors.torch.save_file(merged_tensors, output_shard_path)
else:
torch.save(merged_tensors, output_shard_path)

del merged_tensors
if device != "cpu":
torch.cuda.empty_cache()

LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")
8 changes: 7 additions & 1 deletion src/axolotl/utils/schemas/peft.py
@@ -1,6 +1,6 @@
"""Pydantic models for PEFT-related configuration"""

from typing import Any
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator, model_validator

@@ -130,6 +130,12 @@ class LoraConfig(BaseModel):
)

merge_lora: bool | None = None
merge_method: Literal["standard", "memory_efficient"] | None = Field(
default="standard",
json_schema_extra={
"description": "Method to use for LoRA merging. 'standard' loads the full model into memory, 'memory_efficient' processes shards individually to reduce memory usage."
},
)

@model_validator(mode="before")
@classmethod
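
Putting the new schema field together with the CLI entrypoint, a minimal sketch of driving the merge from Python could look like this (values and paths are placeholders; the config keys match the ones used in the diffs above):

```python
from axolotl.cli.merge_lora import do_merge_lora
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "meta-llama/Llama-2-7b-hf",  # placeholder base model
        "lora_model_dir": "./outputs/lora-run",    # placeholder adapter directory
        "output_dir": "./outputs/lora-run",
        "save_safetensors": True,
        "merge_method": "memory_efficient",        # the field added in this PR
    }
)
do_merge_lora(cfg=cfg)
```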