54 changes: 51 additions & 3 deletions src/axolotl/cli/merge_lora.py
@@ -4,23 +4,45 @@
from typing import Union

import fire
import torch

from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient

LOG = get_logger(__name__)


def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Merges LoRA adapters with base model using either memory-efficient or legacy approach.

Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
merge_method = (
str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
)
Member: merge_method can only take values: Literal["legacy", "memory_efficient"] so you don't need this string handling.
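A minimal sketch of the simplified dispatch this implies, assuming the schema change later in this PR lands so that cfg.merge_method is already constrained to the two literals (sketch only, not merged code):

# Sketch: relies on the pydantic schema constraining merge_method,
# so no lowercasing / dash normalization is needed here.
if cfg.merge_method == "legacy":
    _do_merge_lora_legacy(cfg=cfg)
else:
    _do_merge_lora_efficient(cfg=cfg)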

if merge_method in {"legacy", "standard"}:
Member: "standard" doesn't exist

LOG.debug("Using legacy LoRA merging method...")
_do_merge_lora_legacy(cfg=cfg)
else:
LOG.debug("Using memory-efficient LoRA merging method...")
try:
_do_merge_lora_efficient(cfg=cfg)
except Exception: # pylint: disable=broad-exception-caught
LOG.exception("Memory-efficient merge failed; falling back to legacy.")
_do_merge_lora_legacy(cfg=cfg)
Member: tbh I'd rather have a hard failure here so we know if something is broken

Suggested change
try:
_do_merge_lora_efficient(cfg=cfg)
except Exception: # pylint: disable=broad-exception-caught
LOG.exception("Memory-efficient merge failed; falling back to legacy.")
_do_merge_lora_legacy(cfg=cfg)
_do_merge_lora_efficient(cfg=cfg)

Member: If there are unsupported combinations (you mentioned DoRA, RSLoRA), we should validate this in the pydantic model and raise an error there.
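A hedged sketch of what that validation could look like in the LoraConfig schema; the DoRA/RSLoRA field names below are illustrative assumptions, not taken from this PR:

@model_validator(mode="after")
def validate_merge_method_support(self):
    # Sketch only: peft_use_dora / peft_use_rslora are illustrative field names.
    if self.merge_method == "memory_efficient" and (
        getattr(self, "peft_use_dora", False) or getattr(self, "peft_use_rslora", False)
    ):
        raise ValueError(
            "merge_method='memory_efficient' only supports standard LoRA; "
            "set merge_method='legacy' for DoRA/RSLoRA adapters."
        )
    return self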



def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
"""
Legacy LoRA merging using merge_and_unload.
Loads the full model into memory before merging.
"""
LOG.debug("Using legacy LoRA merging method...")
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True

@@ -52,6 +74,32 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))


def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
"""
Memory-efficient LoRA merging using shard-by-shard processing.
Does not load the full model into memory.

Note: Currently only supports standard LoRA, not advanced methods like DoRA or RSLoRA.
Will automatically fall back to legacy method for unsupported configurations.
"""
LOG.debug("Using memory-efficient LoRA merging method...")

output_path = Path(cfg.output_dir) / "merged"
safe_tensors = getattr(cfg, "save_safetensors", True)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Perform memory-efficient merge
merge_lora_sharded_efficient(
base_model_path=cfg.base_model,
lora_adapter_path=cfg.lora_model_dir,
output_path=output_path,
safe_tensors=safe_tensors,
device=device,
)

LOG.debug("Memory-efficient LoRA merge completed successfully!")


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
@@ -83,7 +131,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
parsed_cfg.lora_model_dir = parsed_cfg.output_dir
if not Path(parsed_cfg.lora_model_dir).exists():
raise ValueError(
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
)

do_merge_lora(cfg=parsed_cfg)
282 changes: 282 additions & 0 deletions src/axolotl/utils/lora_merge_efficient.py
@@ -0,0 +1,282 @@
"""
Memory-efficient LoRA merging implementation inspired by qlora-pipe.
Processes model shards individually without loading the full model into memory.
"""

import gc
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import safetensors
import safetensors.torch
import torch
from peft import LoraConfig
from tqdm import tqdm

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def find_lora_weights(
lora_state: Dict[str, torch.Tensor], key: str
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Find corresponding LoRA A and B weights for a given key.
"""
clean_key = key[:-7] if key.endswith(".weight") else key

a_key = f"base_model.model.{clean_key}.lora_A.weight"
b_key = f"base_model.model.{clean_key}.lora_B.weight"

lora_a = lora_state.get(a_key)
lora_b = lora_state.get(b_key)

if lora_a is not None and lora_b is not None:
return lora_a, lora_b
return None, None


def get_model_shards(model_path: Path) -> list[Path]:
"""Find all model shards in the given path."""
shards: list[Path] = []

patterns = ["model*.safetensors", "pytorch_model*.bin"]

for pattern in patterns:
shards.extend(model_path.glob(pattern))
if shards:
break

return sorted(shards)


def copy_non_model_files(
input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
"""
Copy all non-model files to the output directory.
Args:
input_path: Source directory
output_path: Destination directory
model_shards: List of model shard files to skip
"""
LOG.info("Copying non-model files to output directory...")

shard_names = {shard.name for shard in model_shards}

for filepath in input_path.glob("*"):
if filepath.is_dir():
continue
if filepath.name in shard_names:
continue
if (
filepath.name.startswith("model") and filepath.suffix == ".safetensors"
) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
continue
if filepath.suffix == ".gguf":
continue

LOG.debug(f"Copying {filepath.name} to output")
shutil.copy2(filepath, output_path)


def merge_lora_sharded_efficient(
base_model_path: Union[str, Path],
lora_adapter_path: Union[str, Path],
output_path: Union[str, Path],
device: str = "cpu",
safe_tensors: bool = True,
) -> None:
"""
Memory-efficient LoRA merging that processes shards individually
without loading the full model into memory.
"""
base_model_path = Path(base_model_path)
lora_adapter_path = Path(lora_adapter_path)
output_path = Path(output_path)

if "/" in str(base_model_path) and not base_model_path.exists():
from huggingface_hub import snapshot_download
Member: Can be a toplevel import


base_model_path = Path(snapshot_download(str(base_model_path)))

os.makedirs(output_path, exist_ok=True)

config_file = lora_adapter_path / "adapter_config.json"
if not config_file.exists():
raise FileNotFoundError(f"LoRA config not found: {config_file}")

lora_config_dict = LoraConfig.from_json_file(str(config_file))
if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0:
raise ValueError("LoRA config 'r' must be > 0")

unsupported_methods = []

# Check for DoRA (Weight-Decomposed LoRA)
if lora_config_dict.get("use_dora", False):
unsupported_methods.append("DoRA (Weight-Decomposed LoRA)")

# Check for AdaLoRA (Adaptive LoRA)
if lora_config_dict.get("use_adalora", False):
unsupported_methods.append("AdaLoRA (Adaptive LoRA)")

# Check for VeRA (Vector-based Random Matrix Adaptation)
if lora_config_dict.get("use_vera", False):
unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)")

# Check for other advanced LoRA variants by task_type
task_type = lora_config_dict.get("task_type", "")
if task_type and task_type not in [
"CAUSAL_LM",
"SEQ_2_SEQ_LM",
"TOKEN_CLS",
"SEQ_CLS",
"QUESTION_ANS",
]:
unsupported_methods.append(f"Task type: {task_type}")

# Check for rank adaptation patterns (AdaLoRA indicators)
if any(
key in lora_config_dict
for key in ["rank_pattern", "alpha_pattern", "target_rank"]
):
unsupported_methods.append("AdaLoRA (rank adaptation detected)")

# Check for advanced initialization methods
init_lora_weights = lora_config_dict.get("init_lora_weights", "")
if init_lora_weights and init_lora_weights not in [
"gaussian",
"loftq",
True,
False,
]:
unsupported_methods.append(f"Advanced initialization: {init_lora_weights}")

if unsupported_methods:
methods_str = ", ".join(unsupported_methods)
raise NotImplementedError(
f"Memory-efficient LoRA merge only supports standard LoRA. "
f"Detected unsupported methods: {methods_str}. "
f"Please use the legacy merge method for advanced LoRA variants."
)

scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"])

LOG.debug(f"LoRA scale factor: {scale}")

lora_file = lora_adapter_path / "adapter_model.safetensors"
if not lora_file.exists():
lora_file = lora_adapter_path / "adapter_model.bin"
if not lora_file.exists():
raise FileNotFoundError(
f"LoRA adapter weights not found in {lora_adapter_path}"
)

LOG.debug(f"Loading LoRA weights from {lora_file}")

if lora_file.suffix == ".safetensors":
lora_state = safetensors.torch.load_file(lora_file)
else:
try:
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614
except TypeError:
lora_state = torch.load(lora_file, map_location="cpu") # nosec B614
Member: why do we need this try/except? can you choose one loading method and stick to it?
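A single loading path could look like the sketch below, assuming a PyTorch version where torch.load accepts weights_only (1.13+), which removes the need for the TypeError fallback:

# Sketch: one loading path, no fallback; weights_only=True avoids unpickling arbitrary objects.
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614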

LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

model_shards = get_model_shards(base_model_path)
if not model_shards:
raise FileNotFoundError(f"No model shards found in {base_model_path}")

LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
copy_non_model_files(base_model_path, output_path, model_shards)

merged_count = 0
total_tensors = 0

for shard_path in tqdm(model_shards, desc="Merging shards"):
merged_tensors = {}
metadata = {}

if shard_path.suffix == ".safetensors":
with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
if hasattr(f, "metadata") and f.metadata():
metadata = f.metadata()

for key in f.keys():
total_tensors += 1
tensor = f.get_tensor(key)
lora_a, lora_b = find_lora_weights(lora_state, key)

if lora_a is not None and lora_b is not None:
merged_count += 1
LOG.debug(
f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}"
)

original_dtype = tensor.dtype
base_fp32 = tensor.to(device).to(torch.float32)
a_fp32 = lora_a.to(device).to(torch.float32)
b_fp32 = lora_b.to(device).to(torch.float32)
delta = scale * (b_fp32 @ a_fp32)
if bool(
lora_config_dict.get("fan_in_fan_out", False)
or lora_config_dict.get("lora_fan_in_fan_out", False)
):
delta = delta.T
merged_tensors[key] = (
(base_fp32 + delta).to(original_dtype).detach().cpu()
)
del base_fp32, a_fp32, b_fp32, delta
Member: this appears to be duplicated below, can be factored out into a helper method
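One possible shape for such a helper, mirroring the merge logic used in both the safetensors and .bin branches (a sketch, not code from this diff):

def _apply_lora_delta(tensor, lora_a, lora_b, scale, fan_in_fan_out, device):
    """Merge scale * (B @ A) into a base tensor in fp32, preserving the original dtype."""
    original_dtype = tensor.dtype
    base_fp32 = tensor.to(device).to(torch.float32)
    delta = scale * (lora_b.to(device).to(torch.float32) @ lora_a.to(device).to(torch.float32))
    if fan_in_fan_out:
        delta = delta.T
    return (base_fp32 + delta).to(original_dtype).detach().cpu()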

else:
merged_tensors[key] = tensor.detach().cpu()
else:
state_dict = torch.load( # nosec B614: loading trusted model weights
shard_path, map_location="cpu", weights_only=True
)
for key, tensor in state_dict.items():
total_tensors += 1
lora_a, lora_b = find_lora_weights(lora_state, key)

if lora_a is not None and lora_b is not None:
merged_count += 1
original_dtype = tensor.dtype
base_fp32 = tensor.to(device).to(torch.float32)
a_fp32 = lora_a.to(device).to(torch.float32)
b_fp32 = lora_b.to(device).to(torch.float32)
delta = scale * (b_fp32 @ a_fp32)
if bool(
lora_config_dict.get("fan_in_fan_out", False)
or lora_config_dict.get("lora_fan_in_fan_out", False)
):
delta = delta.T
merged_tensors[key] = (
(base_fp32 + delta).to(original_dtype).detach().cpu()
)
del base_fp32, a_fp32, b_fp32, delta
else:
merged_tensors[key] = tensor.detach().cpu()

output_shard_path = output_path / shard_path.name
merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
if shard_path.suffix == ".safetensors":
safetensors.torch.save_file(
merged_tensors, output_shard_path, metadata=metadata
)
else:
if safe_tensors:
LOG.warning(
"safe_tensors=True requested but input shards are .bin; preserving .bin format "
"to avoid index mismatches."
)
Member: this is a bit confusing. if the user requests safe_tensors, shouldn't we convert them to safetensors?
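If conversion is the preferred behavior, that branch might look like the sketch below; note the shard index file (e.g. pytorch_model.bin.index.json) would also need its filenames rewritten, which this sketch does not handle:

if safe_tensors:
    # Sketch only: write the merged .bin shard out as safetensors instead.
    # The weight-map index still points at .bin names and is not updated here.
    converted_path = output_shard_path.with_suffix(".safetensors")
    safetensors.torch.save_file(merged_tensors, converted_path, metadata={"format": "pt"})
else:
    torch.save(merged_tensors, output_shard_path)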

torch.save(merged_tensors, output_shard_path)

del merged_tensors
if device != "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")
8 changes: 7 additions & 1 deletion src/axolotl/utils/schemas/peft.py
@@ -1,6 +1,6 @@
"""Pydantic models for PEFT-related configuration"""

from typing import Any
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator, model_validator

@@ -140,6 +140,12 @@ class LoraConfig(BaseModel):
)

merge_lora: bool | None = None
merge_method: Literal["legacy", "memory_efficient"] | None = Field(
Member: slightly prefer merge_lora_method so it's unambiguous

default="memory_efficient",
json_schema_extra={
"description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory."
},
)

@model_validator(mode="before")
@classmethod