From ac27f1d5a46f725cd382f996ac059c58b5d71e3d Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 22 Aug 2025 00:44:39 +0530 Subject: [PATCH 01/17] merge_method added --- src/axolotl/cli/merge_lora.py | 40 ++++++++++++++++++++++++++++--- src/axolotl/utils/schemas/peft.py | 8 ++++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 31fad1b297..6c82de6c40 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -9,18 +9,31 @@ from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger +from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient LOG = get_logger(__name__) def do_merge_lora(*, cfg: DictDefault) -> None: """ - Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config - along with the LoRA adapters to combine them into a single base model. + Merges LoRA adapters with base model using either standard or memory-efficient approach. Args: cfg: Dictionary mapping `axolotl` config keys to values. """ + merge_method = getattr(cfg, "merge_method", "standard") + if merge_method == "memory_efficient": + _do_merge_lora_efficient(cfg=cfg) + else: + _do_merge_lora_standard(cfg=cfg) + + +def _do_merge_lora_standard(*, cfg: DictDefault) -> None: + """ + Standard LoRA merging using `merge_and_unload`. + Loads the full model into memory before merging. + """ + LOG.info("Using standard LoRA merging method...") model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True @@ -49,6 +62,27 @@ def do_merge_lora(*, cfg: DictDefault) -> None: processor.save_pretrained(str(Path(cfg.output_dir) / "merged")) +def _do_merge_lora_efficient(*, cfg: DictDefault) -> None: + """ + Memory-efficient LoRA merging using shard-by-shard processing. + Does not load the full model into memory. + """ + LOG.info("Using memory-efficient LoRA merging method...") + + output_path = Path(cfg.output_dir) / "merged" + safe_tensors = getattr(cfg, "save_safetensors", True) + + # Perform memory-efficient merge + merge_lora_sharded_efficient( + base_model_path=cfg.base_model, + lora_adapter_path=cfg.lora_model_dir, + output_path=output_path, + safe_tensors=safe_tensors, + ) + + LOG.info("Memory-efficient LoRA merge completed successfully!") + + def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: """ Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various @@ -80,7 +114,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: parsed_cfg.lora_model_dir = parsed_cfg.output_dir if not Path(parsed_cfg.lora_model_dir).exists(): raise ValueError( - f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist." 
+ f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`" ) do_merge_lora(cfg=parsed_cfg) diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index de29521cb4..e471595def 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -1,6 +1,6 @@ """Pydantic models for PEFT-related configuration""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator, model_validator @@ -130,6 +130,12 @@ class LoraConfig(BaseModel): ) merge_lora: bool | None = None + merge_method: Literal["standard", "memory_efficient"] | None = Field( + default="standard", + json_schema_extra={ + "description": "Method to use for LoRA merging. 'standard' loads the full model into memory, 'memory_efficient' processes shards individually to reduce memory usage." + }, + ) @model_validator(mode="before") @classmethod From d0c01169dfa25007d77cdd77c29c96b8867ecd08 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Fri, 22 Aug 2025 17:08:15 +0530 Subject: [PATCH 02/17] merge_efficient core implement --- src/axolotl/utils/lora_merge_efficient.py | 214 ++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 src/axolotl/utils/lora_merge_efficient.py diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py new file mode 100644 index 0000000000..cb44abd31a --- /dev/null +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -0,0 +1,214 @@ +""" +Memory-efficient LoRA merging implementation inspired by qlora-pipe. +Processes model shards individually without loading the full model into memory. +""" + +import os +import re +import shutil +from pathlib import Path +from typing import Dict, Optional, Union + +import safetensors.torch +import torch +from peft import LoraConfig +from tqdm import tqdm + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def find_lora_weights( + lora_state: Dict[str, torch.Tensor], key: str +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Find corresponding LoRA A and B weights for a given key. + """ + clean_key = key.strip(".weight") + clean_key = re.sub(r"^(base_model\.model\.|language_model\.)", "", clean_key) + + lora_a = None + lora_b = None + + for lora_key, lora_weight in lora_state.items(): + if clean_key in lora_key: + if "lora_A" in lora_key: + lora_a = lora_weight + elif "lora_B" in lora_key: + lora_b = lora_weight + + if lora_a is not None and lora_b is not None: + return lora_a, lora_b + return None, None + + +def get_model_shards(model_path: Path) -> list[Path]: + """Find all model shards in the given path.""" + shards = list[Path]() + + patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"] + + for pattern in patterns: + shards.extend(model_path.glob(pattern)) + if shards: + break + + return sorted(shards) + + +def copy_non_model_files( + input_path: Path, output_path: Path, model_shards: list[Path] +) -> None: + """ + Copy all non-model files to the output directory. 
+ + Args: + input_path: Source directory + output_path: Destination directory + model_shards: List of model shard files to skip + """ + LOG.info("Copying non-model files to output directory...") + + shard_names = {shard.name for shard in model_shards} + + for filepath in input_path.glob("*"): + if filepath.is_dir(): + continue + if filepath.name in shard_names: + continue + if filepath.suffix == ".gguf": + continue + if filepath.name.startswith("model") and filepath.suffix == ".safetensors": + continue + + LOG.debug(f"Copying {filepath.name} to output") + shutil.copy(filepath, output_path) + + +def merge_lora_sharded_efficient( + base_model_path: Union[str, Path], + lora_adapter_path: Union[str, Path], + output_path: Union[str, Path], + device: str = "cuda", + safe_tensors: bool = True, +) -> None: + """ + Memory-efficient LoRA merging that processes shards individually + without loading the full model into memory. + """ + base_model_path = Path(base_model_path) + lora_adapter_path = Path(lora_adapter_path) + output_path = Path(output_path) + + if "/" in str(base_model_path) and not base_model_path.exists(): + from huggingface_hub import snapshot_download + + base_model_path = Path(snapshot_download(str(base_model_path))) + + os.makedirs(output_path, exist_ok=True) + + config_file = lora_adapter_path / "adapter_config.json" + if not config_file.exists(): + raise FileNotFoundError(f"LoRA config not found: {config_file}") + + lora_config = LoraConfig.from_json_file(config_file) + scale = lora_config.lora_alpha / lora_config.r + + LOG.info(f"LoRA scale factor: {scale}") + + lora_file = lora_adapter_path / "adapter_model.safetensors" + if not lora_file.exists(): + lora_file = lora_adapter_path / "adapter_model.bin" + if not lora_file.exists(): + raise FileNotFoundError( + f"LoRA adapter weights not found in {lora_adapter_path}" + ) + + LOG.info(f"Loading LoRA weights from {lora_file}") + + if lora_file.suffix == ".safetensors": + lora_state = safetensors.torch.load_file(lora_file) + else: + lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) + + if device != "cpu": + LOG.info(f"Moving LoRA weights to {device}") + for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): + lora_state[key] = value.to(device) + + model_shards = get_model_shards(base_model_path) + if not model_shards: + raise FileNotFoundError(f"No model shards found in {base_model_path}") + + LOG.info(f"Found {len(model_shards)} model shards") + copy_non_model_files(base_model_path, output_path, model_shards) + + merged_count = 0 + total_tensors = 0 + + for shard_path in tqdm(model_shards, desc="Merging shards"): + merged_tensors = {} + metadata = {} + + if shard_path.suffix == ".safetensors": + with safetensors.safe_open(shard_path, framework="pt", device=device) as f: + if hasattr(f, "metadata") and f.metadata(): + metadata = f.metadata() + + for key in f.keys(): + total_tensors += 1 + tensor = f.get_tensor(key) + lora_a, lora_b = find_lora_weights(lora_state, key) + + if lora_a is not None and lora_b is not None: + merged_count += 1 + LOG.debug( + f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}" + ) + + original_dtype = tensor.dtype + tensor_fp32 = tensor.to(torch.float32) + + delta = scale * ( + lora_b.to(torch.float32) @ lora_a.to(torch.float32) + ) + + merged_tensor = (tensor_fp32 + delta).to(original_dtype) + merged_tensors[key] = merged_tensor + else: + merged_tensors[key] = tensor + else: + state_dict = torch.load( + shard_path, map_location=device + ) # nosec B614: loading 
trusted model weights
+            for key, tensor in state_dict.items():
+                total_tensors += 1
+                lora_a, lora_b = find_lora_weights(lora_state, key)
+
+                if lora_a is not None and lora_b is not None:
+                    merged_count += 1
+                    original_dtype = tensor.dtype
+                    tensor_fp32 = tensor.to(torch.float32)
+                    delta = scale * (
+                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
+                    )
+                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                else:
+                    merged_tensors[key] = tensor
+
+        output_shard_path = output_path / shard_path.name
+        if safe_tensors and shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(
+                merged_tensors, output_shard_path, metadata=metadata
+            )
+        else:
+            if safe_tensors:
+                output_shard_path = output_shard_path.with_suffix(".safetensors")
+            torch.save(merged_tensors, output_shard_path)
+
+        del merged_tensors
+        if device != "cpu":
+            torch.cuda.empty_cache()
+
+    LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")

From 95f224c2b5a7cd102332c3150cf14fe5943df191 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Fri, 22 Aug 2025 18:55:04 +0530
Subject: [PATCH 03/17] Update src/axolotl/cli/merge_lora.py

Co-authored-by: Wing Lian
---
 src/axolotl/cli/merge_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 6c82de6c40..943702f49c 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -114,7 +114,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         parsed_cfg.lora_model_dir = parsed_cfg.output_dir
     if not Path(parsed_cfg.lora_model_dir).exists():
         raise ValueError(
-            f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`"
+            f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
         )

     do_merge_lora(cfg=parsed_cfg)

From 4ccf5ae162a80ef454d8bc43a84ecc6de73b8fdd Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Fri, 22 Aug 2025 19:19:52 +0530
Subject: [PATCH 04/17] Update src/axolotl/utils/lora_merge_efficient.py

Co-authored-by: Wing Lian
---
 src/axolotl/utils/lora_merge_efficient.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py
index cb44abd31a..79922d3b7a 100644
--- a/src/axolotl/utils/lora_merge_efficient.py
+++ b/src/axolotl/utils/lora_merge_efficient.py
@@ -125,7 +125,7 @@ def merge_lora_sharded_efficient(
             f"LoRA adapter weights not found in {lora_adapter_path}"
         )

-    LOG.info(f"Loading LoRA weights from {lora_file}")
+    LOG.debug(f"Loading LoRA weights from {lora_file}")

     if lora_file.suffix == ".safetensors":
         lora_state = safetensors.torch.load_file(lora_file)
     else:

From f094ea24abc3a9efa56eadad153aaf7bf2084525 Mon Sep 17 00:00:00 2001
From: ved1beta
Date: Fri, 22 Aug 2025 20:14:51 +0530
Subject: [PATCH 05/17] standard to legacy + rstrip + try/except for
 do_merge_lora_efficient(cfg=cfg)

---
 src/axolotl/cli/merge_lora.py             | 28 +++++++++++++++--------
 src/axolotl/utils/lora_merge_efficient.py | 15 +++++-------
 src/axolotl/utils/schemas/peft.py         |  6 ++---
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 943702f49c..9363cdff94 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -16,24 +16,31 @@

 def do_merge_lora(*, cfg: DictDefault) -> None:
     """
-    Merges LoRA adapters with base model using either standard or memory-efficient approach.
+    Merges LoRA adapters with base model using either memory-efficient or legacy approach.

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
     """
-    merge_method = getattr(cfg, "merge_method", "standard")
-    if merge_method == "memory_efficient":
-        _do_merge_lora_efficient(cfg=cfg)
+    merge_method = getattr(cfg, "merge_method", "memory_efficient")
+    LOG.info(f"Using {merge_method} LoRA merge method")
+
+    if merge_method == "legacy":
+        _do_merge_lora_legacy(cfg=cfg)
     else:
-        _do_merge_lora_standard(cfg=cfg)
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except RuntimeError as e:
+            LOG.error(f"Memory-efficient merge failed: {e}")
+            LOG.info("Falling back to legacy merge method...")
+            _do_merge_lora_legacy(cfg=cfg)


-def _do_merge_lora_standard(*, cfg: DictDefault) -> None:
+def _do_merge_lora_legacy(*, cfg: DictDefault) -> None:
     """
-    Standard LoRA merging using `merge_and_unload`.
+    Legacy LoRA merging using `merge_and_unload`.
     Loads the full model into memory before merging.
     """
-    LOG.info("Using standard LoRA merging method...")
+    LOG.info("Using legacy LoRA merging method...")
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
     safe_serialization = cfg.save_safetensors is True

@@ -66,6 +73,9 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
     """
     Memory-efficient LoRA merging using shard-by-shard processing.
     Does not load the full model into memory.
+
+    Note: Currently only supports standard LoRA, not advanced methods like DoRA or RSLoRA.
+    Will automatically fall back to legacy method for unsupported configurations.
     """
     LOG.info("Using memory-efficient LoRA merging method...")

@@ -114,7 +124,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
         parsed_cfg.lora_model_dir = parsed_cfg.output_dir
     if not Path(parsed_cfg.lora_model_dir).exists():
         raise ValueError(
-            f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
+            f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`"
         )

     do_merge_lora(cfg=parsed_cfg)

diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py
index 79922d3b7a..2ee54f49af 100644
--- a/src/axolotl/utils/lora_merge_efficient.py
+++ b/src/axolotl/utils/lora_merge_efficient.py
@@ -4,7 +4,6 @@
 import os
-import re
 import shutil
 from pathlib import Path
 from typing import Dict, Optional, Union
@@ -25,18 +24,16 @@ def find_lora_weights(
     """
     Find corresponding LoRA A and B weights for a given key.
""" - clean_key = key.strip(".weight") - clean_key = re.sub(r"^(base_model\.model\.|language_model\.)", "", clean_key) + clean_key = key.rstrip(".weight") lora_a = None lora_b = None for lora_key, lora_weight in lora_state.items(): - if clean_key in lora_key: - if "lora_A" in lora_key: - lora_a = lora_weight - elif "lora_B" in lora_key: - lora_b = lora_weight + if lora_key.endswith(f"{clean_key}.lora_A.weight"): + lora_a = lora_weight + elif lora_key.endswith(f"{clean_key}.lora_B.weight"): + lora_b = lora_weight if lora_a is not None and lora_b is not None: return lora_a, lora_b @@ -47,7 +44,7 @@ def get_model_shards(model_path: Path) -> list[Path]: """Find all model shards in the given path.""" shards = list[Path]() - patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"] + patterns = ["model*.safetensors", "pytorch_model*.bin"] for pattern in patterns: shards.extend(model_path.glob(pattern)) diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index e471595def..b1457015ea 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -130,10 +130,10 @@ class LoraConfig(BaseModel): ) merge_lora: bool | None = None - merge_method: Literal["standard", "memory_efficient"] | None = Field( - default="standard", + merge_method: Literal["legacy", "memory_efficient"] | None = Field( + default="memory_efficient", json_schema_extra={ - "description": "Method to use for LoRA merging. 'standard' loads the full model into memory, 'memory_efficient' processes shards individually to reduce memory usage." + "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory." }, ) From fe157bdc4aaac50277c032c784fda0ecfe7266f0 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Sat, 23 Aug 2025 10:26:55 +0530 Subject: [PATCH 06/17] fix: 'dict' object has no attribute 'lora_alpha' --- src/axolotl/utils/lora_merge_efficient.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py index 2ee54f49af..52e2049561 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -110,7 +110,7 @@ def merge_lora_sharded_efficient( raise FileNotFoundError(f"LoRA config not found: {config_file}") lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config.lora_alpha / lora_config.r + scale = lora_config["lora_alpha"] / lora_config["r"] LOG.info(f"LoRA scale factor: {scale}") From d63de30653c2580de69d4292e9927052f90f79f0 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Sat, 23 Aug 2025 13:05:51 +0530 Subject: [PATCH 07/17] into -> debug --- src/axolotl/cli/merge_lora.py | 4 ++-- src/axolotl/utils/lora_merge_efficient.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 9363cdff94..e79be1e4dc 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -40,7 +40,7 @@ def _do_merge_lora_legacy(*, cfg: DictDefault) -> None: Legacy LoRA merging using `merge_and_unload`. Loads the full model into memory before merging. 
""" - LOG.info("Using legacy LoRA merging method...") + LOG.debug("Using legacy LoRA merging method...") model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg) safe_serialization = cfg.save_safetensors is True @@ -90,7 +90,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None: safe_tensors=safe_tensors, ) - LOG.info("Memory-efficient LoRA merge completed successfully!") + LOG.debug("Memory-efficient LoRA merge completed successfully!") def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py index 52e2049561..29255f2ae7 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -112,7 +112,7 @@ def merge_lora_sharded_efficient( lora_config = LoraConfig.from_json_file(config_file) scale = lora_config["lora_alpha"] / lora_config["r"] - LOG.info(f"LoRA scale factor: {scale}") + LOG.debug(f"LoRA scale factor: {scale}") lora_file = lora_adapter_path / "adapter_model.safetensors" if not lora_file.exists(): @@ -130,7 +130,7 @@ def merge_lora_sharded_efficient( lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) if device != "cpu": - LOG.info(f"Moving LoRA weights to {device}") + LOG.debug(f"Moving LoRA weights to {device}") for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): lora_state[key] = value.to(device) @@ -138,7 +138,7 @@ def merge_lora_sharded_efficient( if not model_shards: raise FileNotFoundError(f"No model shards found in {base_model_path}") - LOG.info(f"Found {len(model_shards)} model shards") + LOG.debug(f"Found {len(model_shards)} model shards") copy_non_model_files(base_model_path, output_path, model_shards) merged_count = 0 From 4b2fc64ca16e8022699656383a70f418119407f0 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Tue, 2 Sep 2025 20:30:54 +0530 Subject: [PATCH 08/17] lint --- src/axolotl/cli/merge_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index e79be1e4dc..b8f40c7e17 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -37,7 +37,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None: def _do_merge_lora_legacy(*, cfg: DictDefault) -> None: """ - Legacy LoRA merging using `merge_and_unload`. + Legacy LoRA merging using merge_and_unload. Loads the full model into memory before merging. 
""" LOG.debug("Using legacy LoRA merging method...") From 1c1c3aba54851425dd79ff9d4c6f5b0ccb448789 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Tue, 2 Sep 2025 20:41:33 +0530 Subject: [PATCH 09/17] lint2 --- src/axolotl/utils/lora_merge_efficient.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py index 29255f2ae7..863cf06c69 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -176,9 +176,7 @@ def merge_lora_sharded_efficient( else: merged_tensors[key] = tensor else: - state_dict = torch.load( - shard_path, map_location=device - ) # nosec B614: loading trusted model weights + state_dict = torch.load(shard_path, map_location=device) # nosec B614: loading trusted model weights for key, tensor in state_dict.items(): total_tensors += 1 lora_a, lora_b = find_lora_weights(lora_state, key) From d01bb1baf000857d2b4588b323b543837f33e487 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Sat, 13 Sep 2025 11:02:28 +0530 Subject: [PATCH 10/17] moved everythign to cpu + peformance improvments --- src/axolotl/cli/merge_lora.py | 22 +-- src/axolotl/utils/lora_merge_efficient.py | 155 ++++++++++++++++------ 2 files changed, 128 insertions(+), 49 deletions(-) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index b8f40c7e17..93d077f763 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -4,6 +4,7 @@ from typing import Union import fire +import torch from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer @@ -21,17 +22,18 @@ def do_merge_lora(*, cfg: DictDefault) -> None: Args: cfg: Dictionary mapping `axolotl` config keys to values. """ - merge_method = getattr(cfg, "merge_method", "memory_efficient") - LOG.info(f"Using {merge_method} LoRA merge method") - - if merge_method == "legacy": + merge_method = ( + str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_") + ) + if merge_method in {"legacy", "standard"}: + LOG.debug("Using legacy LoRA merging method...") _do_merge_lora_legacy(cfg=cfg) else: + LOG.debug("Using memory-efficient LoRA merging method...") try: _do_merge_lora_efficient(cfg=cfg) - except RuntimeError as e: - LOG.error(f"Memory-efficient merge failed: {e}") - LOG.info("Falling back to legacy merge method...") + except Exception: # pylint: disable=broad-exception-caught + LOG.exception("Memory-efficient merge failed; falling back to legacy.") _do_merge_lora_legacy(cfg=cfg) @@ -77,10 +79,11 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None: Note: Currently only supports standard LoRA, not advanced methods like DoRA or RSLoRA. Will automatically fall back to legacy method for unsupported configurations. 
""" - LOG.info("Using memory-efficient LoRA merging method...") + LOG.debug("Using memory-efficient LoRA merging method...") output_path = Path(cfg.output_dir) / "merged" safe_tensors = getattr(cfg, "save_safetensors", True) + device = "cuda" if torch.cuda.is_available() else "cpu" # Perform memory-efficient merge merge_lora_sharded_efficient( @@ -88,6 +91,7 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None: lora_adapter_path=cfg.lora_model_dir, output_path=output_path, safe_tensors=safe_tensors, + device=device, ) LOG.debug("Memory-efficient LoRA merge completed successfully!") @@ -124,7 +128,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: parsed_cfg.lora_model_dir = parsed_cfg.output_dir if not Path(parsed_cfg.lora_model_dir).exists(): raise ValueError( - f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`" + f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`" ) do_merge_lora(cfg=parsed_cfg) diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py index 863cf06c69..9c564fa126 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -3,11 +3,13 @@ Processes model shards individually without loading the full model into memory. """ +import gc import os import shutil from pathlib import Path from typing import Dict, Optional, Union +import safetensors import safetensors.torch import torch from peft import LoraConfig @@ -24,16 +26,13 @@ def find_lora_weights( """ Find corresponding LoRA A and B weights for a given key. """ - clean_key = key.rstrip(".weight") + clean_key = key[:-7] if key.endswith(".weight") else key - lora_a = None - lora_b = None + a_key = f"base_model.model.{clean_key}.lora_A.weight" + b_key = f"base_model.model.{clean_key}.lora_B.weight" - for lora_key, lora_weight in lora_state.items(): - if lora_key.endswith(f"{clean_key}.lora_A.weight"): - lora_a = lora_weight - elif lora_key.endswith(f"{clean_key}.lora_B.weight"): - lora_b = lora_weight + lora_a = lora_state.get(a_key) + lora_b = lora_state.get(b_key) if lora_a is not None and lora_b is not None: return lora_a, lora_b @@ -42,7 +41,7 @@ def find_lora_weights( def get_model_shards(model_path: Path) -> list[Path]: """Find all model shards in the given path.""" - shards = list[Path]() + shards: list[Path] = [] patterns = ["model*.safetensors", "pytorch_model*.bin"] @@ -74,20 +73,22 @@ def copy_non_model_files( continue if filepath.name in shard_names: continue - if filepath.suffix == ".gguf": + if ( + filepath.name.startswith("model") and filepath.suffix == ".safetensors" + ) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"): continue - if filepath.name.startswith("model") and filepath.suffix == ".safetensors": + if filepath.suffix == ".gguf": continue LOG.debug(f"Copying {filepath.name} to output") - shutil.copy(filepath, output_path) + shutil.copy2(filepath, output_path) def merge_lora_sharded_efficient( base_model_path: Union[str, Path], lora_adapter_path: Union[str, Path], output_path: Union[str, Path], - device: str = "cuda", + device: str = "cpu", safe_tensors: bool = True, ) -> None: """ @@ -109,8 +110,61 @@ def merge_lora_sharded_efficient( if not config_file.exists(): raise FileNotFoundError(f"LoRA config not found: {config_file}") - lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config["lora_alpha"] / lora_config["r"] + lora_config_dict = 
LoraConfig.from_json_file(str(config_file)) + if not lora_config_dict.get("r") or lora_config_dict["r"] <= 0: + raise ValueError("LoRA config 'r' must be > 0") + + unsupported_methods = [] + + # Check for DoRA (Weight-Decomposed LoRA) + if lora_config_dict.get("use_dora", False): + unsupported_methods.append("DoRA (Weight-Decomposed LoRA)") + + # Check for AdaLoRA (Adaptive LoRA) + if lora_config_dict.get("use_adalora", False): + unsupported_methods.append("AdaLoRA (Adaptive LoRA)") + + # Check for VeRA (Vector-based Random Matrix Adaptation) + if lora_config_dict.get("use_vera", False): + unsupported_methods.append("VeRA (Vector-based Random Matrix Adaptation)") + + # Check for other advanced LoRA variants by task_type + task_type = lora_config_dict.get("task_type", "") + if task_type and task_type not in [ + "CAUSAL_LM", + "SEQ_2_SEQ_LM", + "TOKEN_CLS", + "SEQ_CLS", + "QUESTION_ANS", + ]: + unsupported_methods.append(f"Task type: {task_type}") + + # Check for rank adaptation patterns (AdaLoRA indicators) + if any( + key in lora_config_dict + for key in ["rank_pattern", "alpha_pattern", "target_rank"] + ): + unsupported_methods.append("AdaLoRA (rank adaptation detected)") + + # Check for advanced initialization methods + init_lora_weights = lora_config_dict.get("init_lora_weights", "") + if init_lora_weights and init_lora_weights not in [ + "gaussian", + "loftq", + True, + False, + ]: + unsupported_methods.append(f"Advanced initialization: {init_lora_weights}") + + if unsupported_methods: + methods_str = ", ".join(unsupported_methods) + raise NotImplementedError( + f"Memory-efficient LoRA merge only supports standard LoRA. " + f"Detected unsupported methods: {methods_str}. " + f"Please use the legacy merge method for advanced LoRA variants." + ) + + scale = float(lora_config_dict["lora_alpha"]) / float(lora_config_dict["r"]) LOG.debug(f"LoRA scale factor: {scale}") @@ -127,18 +181,19 @@ def merge_lora_sharded_efficient( if lora_file.suffix == ".safetensors": lora_state = safetensors.torch.load_file(lora_file) else: - lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) - - if device != "cpu": - LOG.debug(f"Moving LoRA weights to {device}") - for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): - lora_state[key] = value.to(device) + try: + lora_state = torch.load( + lora_file, map_location="cpu", weights_only=True + ) # nosec B614 + except TypeError: + lora_state = torch.load(lora_file, map_location="cpu") # nosec B614 + LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge") model_shards = get_model_shards(base_model_path) if not model_shards: raise FileNotFoundError(f"No model shards found in {base_model_path}") - LOG.debug(f"Found {len(model_shards)} model shards") + LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}") copy_non_model_files(base_model_path, output_path, model_shards) merged_count = 0 @@ -149,7 +204,7 @@ def merge_lora_sharded_efficient( metadata = {} if shard_path.suffix == ".safetensors": - with safetensors.safe_open(shard_path, framework="pt", device=device) as f: + with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f: if hasattr(f, "metadata") and f.metadata(): metadata = f.metadata() @@ -165,18 +220,25 @@ def merge_lora_sharded_efficient( ) original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = 
lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool( + lora_config_dict.get("fan_in_fan_out", False) + or lora_config_dict.get("lora_fan_in_fan_out", False) + ): + delta = delta.T + merged_tensors[key] = ( + (base_fp32 + delta).to(original_dtype).detach().cpu() ) - - merged_tensor = (tensor_fp32 + delta).to(original_dtype) - merged_tensors[key] = merged_tensor + del base_fp32, a_fp32, b_fp32, delta else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu() else: - state_dict = torch.load(shard_path, map_location=device) # nosec B614: loading trusted model weights + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + ) for key, tensor in state_dict.items(): total_tensors += 1 lora_a, lora_b = find_lora_weights(lora_state, key) @@ -184,26 +246,39 @@ def merge_lora_sharded_efficient( if lora_a is not None and lora_b is not None: merged_count += 1 original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool( + lora_config_dict.get("fan_in_fan_out", False) + or lora_config_dict.get("lora_fan_in_fan_out", False) + ): + delta = delta.T + merged_tensors[key] = ( + (base_fp32 + delta).to(original_dtype).detach().cpu() ) - merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype) + del base_fp32, a_fp32, b_fp32, delta else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu() output_shard_path = output_path / shard_path.name - if safe_tensors and shard_path.suffix == ".safetensors": + merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} + if shard_path.suffix == ".safetensors": safetensors.torch.save_file( merged_tensors, output_shard_path, metadata=metadata ) else: if safe_tensors: - output_shard_path = output_shard_path.with_suffix(".safetensors") + LOG.warning( + "safe_tensors=True requested but input shards are .bin; preserving .bin format " + "to avoid index mismatches." 
+                )
             torch.save(merged_tensors, output_shard_path)

         del merged_tensors
         if device != "cpu" and torch.cuda.is_available():
             torch.cuda.empty_cache()
+        gc.collect()

     LOG.info(f"Applied LoRA to {merged_count}/{total_tensors} tensors")

From f94415fa27d8be155cb2d83d69e107a7d8f02b7e Mon Sep 17 00:00:00 2001
From: ved1beta
Date: Sat, 13 Sep 2025 11:04:23 +0530
Subject: [PATCH 11/17] lint

---
 src/axolotl/utils/lora_merge_efficient.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py
index 9c564fa126..f8025bba8b 100644
--- a/src/axolotl/utils/lora_merge_efficient.py
+++ b/src/axolotl/utils/lora_merge_efficient.py
@@ -182,9 +182,7 @@ def merge_lora_sharded_efficient(
         lora_state = safetensors.torch.load_file(lora_file)
     else:
         try:
-            lora_state = torch.load(
-                lora_file, map_location="cpu", weights_only=True
-            )  # nosec B614
+            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614
         except TypeError:
             lora_state = torch.load(lora_file, map_location="cpu")  # nosec B614
     LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

From 86e86a7bae6e6bc7673e5b9b80a7c6d2feb056bd Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Wed, 24 Sep 2025 21:56:43 +0530
Subject: [PATCH 12/17] Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders
---
 src/axolotl/cli/merge_lora.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 8ef0e653d9..4a1074e822 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -88,7 +88,6 @@ def _do_merge_lora_efficient(*, cfg: DictDefault) -> None:
     safe_tensors = getattr(cfg, "save_safetensors", True)
     device = "cuda" if torch.cuda.is_available() else "cpu"

-    # Perform memory-efficient merge
     merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,

From 3384fc5e1ec71bfa30598f5d54086591c004a664 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Wed, 24 Sep 2025 21:57:14 +0530
Subject: [PATCH 13/17] Update src/axolotl/cli/merge_lora.py

Co-authored-by: Dan Saunders
---
 src/axolotl/cli/merge_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 4a1074e822..2776a27859 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -25,7 +25,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     merge_method = (
         str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
     )
-    if merge_method in {"legacy", "standard"}:
+    if merge_method == "legacy":
         LOG.debug("Using legacy LoRA merging method...")
         _do_merge_lora_legacy(cfg=cfg)
     else:

From d2ce1ab6a1530c9859a535260e3fcbe555041f98 Mon Sep 17 00:00:00 2001
From: ved1beta
Date: Wed, 24 Sep 2025 22:31:08 +0530
Subject: [PATCH 14/17] string handling + try/except remove

---
 src/axolotl/cli/merge_lora.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py
index 2776a27859..a7fe553c4f 100644
--- a/src/axolotl/cli/merge_lora.py
+++ b/src/axolotl/cli/merge_lora.py
@@ -22,19 +22,13 @@ def do_merge_lora(*, cfg: DictDefault) -> None:

     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
""" - merge_method = ( - str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_") - ) + merge_method = str(getattr(cfg, "merge_method", "")) if merge_method == "legacy": LOG.debug("Using legacy LoRA merging method...") _do_merge_lora_legacy(cfg=cfg) else: LOG.debug("Using memory-efficient LoRA merging method...") - try: - _do_merge_lora_efficient(cfg=cfg) - except Exception: # pylint: disable=broad-exception-caught - LOG.exception("Memory-efficient merge failed; falling back to legacy.") - _do_merge_lora_legacy(cfg=cfg) + _do_merge_lora_efficient(cfg=cfg) def _do_merge_lora_legacy(*, cfg: DictDefault) -> None: From 6e0617ad66015017004614a488980905c0573250 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Wed, 24 Sep 2025 22:42:25 +0530 Subject: [PATCH 15/17] merge_method -> merge_lora_methods --- src/axolotl/utils/lora_merge_efficient.py | 7 ++----- src/axolotl/utils/schemas/peft.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/utils/lora_merge_efficient.py index f8025bba8b..2913d679d1 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/utils/lora_merge_efficient.py @@ -16,6 +16,7 @@ from tqdm import tqdm from axolotl.utils.logging import get_logger +from huggingface_hub import snapshot_download LOG = get_logger(__name__) @@ -100,7 +101,6 @@ def merge_lora_sharded_efficient( output_path = Path(output_path) if "/" in str(base_model_path) and not base_model_path.exists(): - from huggingface_hub import snapshot_download base_model_path = Path(snapshot_download(str(base_model_path))) @@ -181,10 +181,7 @@ def merge_lora_sharded_efficient( if lora_file.suffix == ".safetensors": lora_state = safetensors.torch.load_file(lora_file) else: - try: - lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614 - except TypeError: - lora_state = torch.load(lora_file, map_location="cpu") # nosec B614 + lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614 LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge") model_shards = get_model_shards(base_model_path) diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py index e9846a8f15..1cb5acb2a8 100644 --- a/src/axolotl/utils/schemas/peft.py +++ b/src/axolotl/utils/schemas/peft.py @@ -140,7 +140,7 @@ class LoraConfig(BaseModel): ) merge_lora: bool | None = None - merge_method: Literal["legacy", "memory_efficient"] | None = Field( + merge_lora_method: Literal["legacy", "memory_efficient"] | None = Field( default="memory_efficient", json_schema_extra={ "description": "Method to use for LoRA merging. 'memory_efficient' (default) processes shards individually to reduce memory usage, 'legacy' loads the full model into memory." 
From dd2428627a5b6419169d06ee7cfeb8be4ba07810 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 25 Sep 2025 00:45:12 +0530 Subject: [PATCH 16/17] remove duplicate cal + safetensor + move to lora_merge.py --- src/axolotl/cli/merge_lora.py | 2 +- .../utils/lora_merge.py} | 116 ++++++++++-------- 2 files changed, 64 insertions(+), 54 deletions(-) rename src/axolotl/{utils/lora_merge_efficient.py => cli/utils/lora_merge.py} (74%) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index a7fe553c4f..7b6998863e 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -10,7 +10,7 @@ from axolotl.cli.utils import load_model_and_tokenizer from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient +from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient LOG = get_logger(__name__) diff --git a/src/axolotl/utils/lora_merge_efficient.py b/src/axolotl/cli/utils/lora_merge.py similarity index 74% rename from src/axolotl/utils/lora_merge_efficient.py rename to src/axolotl/cli/utils/lora_merge.py index 2913d679d1..e358216642 100644 --- a/src/axolotl/utils/lora_merge_efficient.py +++ b/src/axolotl/cli/utils/lora_merge.py @@ -1,7 +1,3 @@ -""" -Memory-efficient LoRA merging implementation inspired by qlora-pipe. -Processes model shards individually without loading the full model into memory. -""" import gc import os @@ -85,6 +81,50 @@ def copy_non_model_files( shutil.copy2(filepath, output_path) +def _merge_tensor_with_lora( + tensor: torch.Tensor, + key: str, + lora_state: Dict[str, torch.Tensor], + scale: float, + lora_config_dict: Dict, + device: str, +) -> torch.Tensor: + """ + Helper function to merge a single tensor with its corresponding LoRA weights. 
+ + Args: + tensor: Base model tensor + key: Tensor key/name + lora_state: Dictionary containing LoRA weights + scale: LoRA scaling factor (alpha/r) + lora_config_dict: LoRA configuration dictionary + device: Device to perform computations on + + Returns: + Merged tensor with LoRA applied + """ + lora_a, lora_b = find_lora_weights(lora_state, key) + + if lora_a is not None and lora_b is not None: + LOG.debug(f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}") + + original_dtype = tensor.dtype + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool( + lora_config_dict.get("fan_in_fan_out", False) + or lora_config_dict.get("lora_fan_in_fan_out", False) + ): + delta = delta.T + merged_tensor = (base_fp32 + delta).to(original_dtype).detach().cpu() + del base_fp32, a_fp32, b_fp32, delta + return merged_tensor, True + else: + return tensor.detach().cpu(), False + + def merge_lora_sharded_efficient( base_model_path: Union[str, Path], lora_adapter_path: Union[str, Path], @@ -101,7 +141,6 @@ def merge_lora_sharded_efficient( output_path = Path(output_path) if "/" in str(base_model_path) and not base_model_path.exists(): - base_model_path = Path(snapshot_download(str(base_model_path))) os.makedirs(output_path, exist_ok=True) @@ -206,70 +245,41 @@ def merge_lora_sharded_efficient( for key in f.keys(): total_tensors += 1 tensor = f.get_tensor(key) - lora_a, lora_b = find_lora_weights(lora_state, key) - - if lora_a is not None and lora_b is not None: + merged_tensor, was_merged = _merge_tensor_with_lora( + tensor, key, lora_state, scale, lora_config_dict, device + ) + merged_tensors[key] = merged_tensor + if was_merged: merged_count += 1 - LOG.debug( - f"Merging LoRA for {key}: {lora_a.shape}, {lora_b.shape}" - ) - - original_dtype = tensor.dtype - base_fp32 = tensor.to(device).to(torch.float32) - a_fp32 = lora_a.to(device).to(torch.float32) - b_fp32 = lora_b.to(device).to(torch.float32) - delta = scale * (b_fp32 @ a_fp32) - if bool( - lora_config_dict.get("fan_in_fan_out", False) - or lora_config_dict.get("lora_fan_in_fan_out", False) - ): - delta = delta.T - merged_tensors[key] = ( - (base_fp32 + delta).to(original_dtype).detach().cpu() - ) - del base_fp32, a_fp32, b_fp32, delta - else: - merged_tensors[key] = tensor.detach().cpu() else: state_dict = torch.load( # nosec B614: loading trusted model weights shard_path, map_location="cpu", weights_only=True ) for key, tensor in state_dict.items(): total_tensors += 1 - lora_a, lora_b = find_lora_weights(lora_state, key) - - if lora_a is not None and lora_b is not None: + merged_tensor, was_merged = _merge_tensor_with_lora( + tensor, key, lora_state, scale, lora_config_dict, device + ) + merged_tensors[key] = merged_tensor + if was_merged: merged_count += 1 - original_dtype = tensor.dtype - base_fp32 = tensor.to(device).to(torch.float32) - a_fp32 = lora_a.to(device).to(torch.float32) - b_fp32 = lora_b.to(device).to(torch.float32) - delta = scale * (b_fp32 @ a_fp32) - if bool( - lora_config_dict.get("fan_in_fan_out", False) - or lora_config_dict.get("lora_fan_in_fan_out", False) - ): - delta = delta.T - merged_tensors[key] = ( - (base_fp32 + delta).to(original_dtype).detach().cpu() - ) - del base_fp32, a_fp32, b_fp32, delta - else: - merged_tensors[key] = tensor.detach().cpu() output_shard_path = output_path / shard_path.name merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} - if 
shard_path.suffix == ".safetensors": + + if safe_tensors: + if not str(output_shard_path).endswith(".safetensors"): + output_shard_path = output_path / (shard_path.stem + ".safetensors") safetensors.torch.save_file( merged_tensors, output_shard_path, metadata=metadata ) else: - if safe_tensors: - LOG.warning( - "safe_tensors=True requested but input shards are .bin; preserving .bin format " - "to avoid index mismatches." + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file( + merged_tensors, output_shard_path, metadata=metadata ) - torch.save(merged_tensors, output_shard_path) + else: + torch.save(merged_tensors, output_shard_path) del merged_tensors if device != "cpu" and torch.cuda.is_available(): From d660e66c9a3aa40121cfb54c24cb5c2497280d44 Mon Sep 17 00:00:00 2001 From: ved1beta Date: Thu, 25 Sep 2025 00:46:50 +0530 Subject: [PATCH 17/17] lint --- src/axolotl/cli/merge_lora.py | 2 +- src/axolotl/cli/utils/lora_merge.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 7b6998863e..40504518c0 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -8,9 +8,9 @@ from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer +from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger -from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient LOG = get_logger(__name__) diff --git a/src/axolotl/cli/utils/lora_merge.py b/src/axolotl/cli/utils/lora_merge.py index e358216642..74c94ed749 100644 --- a/src/axolotl/cli/utils/lora_merge.py +++ b/src/axolotl/cli/utils/lora_merge.py @@ -1,4 +1,3 @@ - import gc import os import shutil @@ -8,11 +7,11 @@ import safetensors import safetensors.torch import torch +from huggingface_hub import snapshot_download from peft import LoraConfig from tqdm import tqdm from axolotl.utils.logging import get_logger -from huggingface_hub import snapshot_download LOG = get_logger(__name__)
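
For reference, a minimal usage sketch of the merge entry point as it stands after PATCH 17/17. The import path, function name, and keyword arguments come from the final diffs above; the model id and directories are hypothetical placeholders. Config-driven runs select the same code path through the `merge_lora_method` field ("legacy" or "memory_efficient", defaulting to "memory_efficient") added to `LoraConfig` in PATCH 15/17.

    from pathlib import Path

    from axolotl.cli.utils.lora_merge import merge_lora_sharded_efficient

    # Hypothetical paths -- point these at a real base model and adapter run.
    merge_lora_sharded_efficient(
        base_model_path="NousResearch/Llama-2-7b-hf",  # HF hub id or local checkpoint dir
        lora_adapter_path=Path("outputs/lora-out"),    # must contain adapter_config.json and adapter weights
        output_path=Path("outputs/lora-out/merged"),
        device="cpu",        # shards are merged one at a time on CPU by default (PATCH 10)
        safe_tensors=True,   # write merged shards as .safetensors
    )

Because the merge is W + (lora_alpha / r) * (B @ A) applied shard by shard, peak memory stays near the size of a single shard plus the adapter weights rather than the full model.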