feat:merge-lora iterate through bins without loading #3095
base: main
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.

📝 Walkthrough

Adds a shard-wise, memory-efficient LoRA merging utility and integrates it into the CLI with a dispatch that prefers the memory-efficient method (default) and falls back to the legacy in-memory merge on RuntimeError; also adds a `merge_method` option to the config schema.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Codecov Report: ❌ Patch coverage is …
Actionable comments posted: 4
🧹 Nitpick comments (12)
src/axolotl/utils/schemas/peft.py (2)
132-138: Enum looks good; consider documenting/deprecating the old boolean to avoid confusion

Introducing `merge_method: Literal["standard", "memory_efficient"]` is a clear improvement. However, `merge_lora: bool | None` still exists and is referenced elsewhere, which can be confusing to users. Recommend:

- Mark `merge_lora` as deprecated in the field's description and CLI help.
- In docs/examples, prefer `merge_method` and explain when to set `merge_lora=True` (still used by validators/CLI).
154-177: Validator still keys off `merge_lora`; clarify interaction with `merge_method`

`validate_qlora` gates behavior on `self.merge_lora`. If a user configures only `merge_method="memory_efficient"` but forgets `merge_lora=True`, validation may incorrectly follow the "not merging" path. Since `do_cli()` currently forces `merge_lora=True`, this is fine for the CLI path, but configs used programmatically may drift.

Option A (minimal): Document that `merge_lora` must be True when invoking a merge, regardless of `merge_method`.

Option B (preferred): Make the validator robust by checking an explicit "are we merging" signal that can be derived from the CLI invocation or `merge_method` when the merge entrypoint is called. If changing semantics is risky, emit a warning when `merge_method != "standard"` and `merge_lora is not True`.

src/axolotl/utils/lora_merge_efficient.py (7)
60-87: Skip index files and preserve metadata when copying; avoid needless large-file copies

- If `get_model_shards` missed some shard patterns, `copy_non_model_files` can end up copying large shard files only to overwrite them later.
- Also skip index JSONs (`*.index.json`) and use `copy2` to preserve timestamps and metadata.

```diff
@@
-    for filepath in input_path.glob("*"):
+    for filepath in input_path.glob("*"):
         if filepath.is_dir():
             continue
         if filepath.name in shard_names:
             continue
+        # Skip HF index files and other model indices
+        if filepath.name.endswith(".index.json"):
+            continue
         if filepath.suffix == ".gguf":
             continue
         if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
             continue
@@
-        shutil.copy(filepath, output_path)
+        shutil.copy2(filepath, output_path)
```
100-109: Graceful fallback when CUDA is unavailable

Defaulting to `"cuda"` can fail on CPU-only hosts. Fall back to CPU and log once.

```diff
@@
-    base_model_path = Path(base_model_path)
+    base_model_path = Path(base_model_path)
@@
-    output_path = Path(output_path)
+    output_path = Path(output_path)
+
+    if device == "cuda" and not torch.cuda.is_available():
+        LOG.warning("CUDA not available; falling back to CPU for merge.")
+        device = "cpu"
```
115-119: Guard against missing/zero rank in adapter config

Division by zero or `None` will raise without a helpful message.

```diff
-    scale = lora_config.lora_alpha / lora_config.r
+    if not getattr(lora_config, "r", None):
+        raise ValueError("Invalid LoRA config: rank `r` is missing or zero.")
+    scale = lora_config.lora_alpha / lora_config.r
```
130-134: Compat: `torch.load(weights_only=True)` isn't available on older Torch

Provide a fallback for wider compatibility.

```diff
-    else:
-        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+    else:
+        try:
+            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+        except TypeError:
+            # Fallback for older torch versions
+            lora_state = torch.load(lora_file, map_location="cpu")
```
154-181: Reduce GPU memory pressure; ensure CPU tensors before saving; handle fan_in_fan_out

- Reading tensors straight onto CUDA and accumulating them in `merged_tensors` keeps a whole shard on GPU; flip results back to CPU eagerly.
- Apply `fan_in_fan_out` if present to avoid incorrect orientation on models that require it.

```diff
@@
-            for key in f.keys():
+            for key in f.keys():
                 total_tensors += 1
                 tensor = f.get_tensor(key)
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-
-                    merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                    merged_tensors[key] = merged_tensor
+                    delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+                    # Handle fan_in_fan_out if present
+                    if getattr(lora_config, "fan_in_fan_out", False):
+                        delta = delta.T
+                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.to("cpu")
```
182-199: Mirror CPU-offload approach for .bin shards

Same memory considerations should apply when shards are `.bin`.

```diff
@@
-            for key, tensor in state_dict.items():
+            for key, tensor in state_dict.items():
                 total_tensors += 1
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                    delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+                    if getattr(lora_config, "fan_in_fan_out", False):
+                        delta = delta.T
+                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.to("cpu")
```
135-139: Optional: Avoid moving entire LoRA state to GPU upfront

Moving the whole LoRA state to CUDA may spike memory on large adapters. Consider keeping LoRA on CPU and moving just the A/B tensors for the tensor being merged in the loop (or using pinned memory); a sketch follows below. This is a targeted optimization and can be a follow-up.
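A rough sketch of that per-tensor alternative; the helper name and signature are made up for illustration (the real code keeps this logic inline in the shard loop):

```python
import torch


def merge_one_tensor(tensor, lora_a, lora_b, scale, device="cuda", fan_in_fan_out=False):
    """Merge a single base tensor with its LoRA pair, touching the GPU only briefly (sketch)."""
    if lora_a is None or lora_b is None:
        return tensor
    # Move only this tensor's A/B pair to the device, compute in fp32, then come back to CPU.
    a_fp32 = lora_a.to(device).to(torch.float32)
    b_fp32 = lora_b.to(device).to(torch.float32)
    delta = scale * (b_fp32 @ a_fp32)
    if fan_in_fan_out:
        delta = delta.T
    merged = (tensor.to(device).to(torch.float32) + delta).to(tensor.dtype)
    # Return on CPU so only one tensor's worth of GPU memory is held at a time.
    return merged.cpu()
```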
src/axolotl/cli/merge_lora.py (3)
24-29: Validate `merge_method` early and log the choice

Guard against typos and make behavior explicit.

```diff
-    merge_method = getattr(cfg, "merge_method", "standard")
+    merge_method = getattr(cfg, "merge_method", "standard")
+    if merge_method not in ("standard", "memory_efficient"):
+        raise ValueError(f"Invalid merge_method: {merge_method!r}. Expected 'standard' or 'memory_efficient'.")
+    LOG.info("Selected LoRA merge method: %s", merge_method)
```
88-91: Nit: docstring typo

Use `load_in_4bit` (with underscore) for consistency with the actual config key.

```diff
-    (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
+    (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.).
```
113-118: Support remote LoRA adapters (HF Hub) or improve error message

The memory-efficient implementation currently requires a local adapter directory. Consider:

- Supporting Hub IDs by calling `snapshot_download` (similar to base model logic; a sketch follows below), or
- Clarifying the message to "LoRA adapter directory does not exist" to avoid implying the merged output should already exist.

```diff
-        raise ValueError(
-            f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`"
-        )
+        raise ValueError(
+            f"LoRA adapter directory does not exist: `{parsed_cfg.lora_model_dir}`. "
+            "Provide a local path to the adapter (directory containing adapter_config.json)."
+        )
```
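For the Hub-ID option, a hedged sketch of how the adapter path could be resolved before merging; the helper name and the `allow_patterns` filter are assumptions, not the PR's actual code:

```python
from pathlib import Path

from huggingface_hub import snapshot_download


def resolve_adapter_path(lora_model_dir: str) -> Path:
    """Return a local directory for the adapter, downloading from the Hub if needed (sketch)."""
    local = Path(lora_model_dir)
    if local.exists():
        return local
    # Treat non-existent paths that look like repo ids as Hub references.
    return Path(snapshot_download(repo_id=lora_model_dir, allow_patterns=["adapter_*", "*.json"]))
```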
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- src/axolotl/cli/merge_lora.py (3 hunks)
- src/axolotl/utils/lora_merge_efficient.py (1 hunks)
- src/axolotl/utils/schemas/peft.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/lora_merge_efficient.py (2)
- src/axolotl/utils/schemas/peft.py (1): LoraConfig (28-191)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
src/axolotl/cli/merge_lora.py (3)
- src/axolotl/utils/lora_merge_efficient.py (1): merge_lora_sharded_efficient (89-214)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
159-159: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: preview
- GitHub Check: pre-commit
src/axolotl/cli/merge_lora.py
Outdated
```python
merge_method = getattr(cfg, "merge_method", "standard")
if merge_method == "memory_efficient":
    _do_merge_lora_efficient(cfg=cfg)
else:
    _do_merge_lora_standard(cfg=cfg)
```
let's be opinionated and prefer the memory_efficient merge. the old "standard" version can be a manual fallback if users have issues. You could also do something like wrap the _do_merge_lora_efficient in a try/catch and if it fails, provide a hint to the user to use the standard option
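A minimal sketch of that dispatch, assuming the helper names, `cfg`, and `LOG` used elsewhere in this thread; the `"legacy"` value mirrors the rename suggested in the next comment and is illustrative only:

```python
def do_merge_lora(cfg):
    """Prefer the memory-efficient merge; fall back to the legacy path with a hint (sketch)."""
    if getattr(cfg, "merge_method", None) == "legacy":
        _do_merge_lora_legacy(cfg=cfg)
        return
    try:
        _do_merge_lora_efficient(cfg=cfg)
    except RuntimeError as err:
        LOG.error(f"Memory-efficient merge failed: {err}")
        LOG.info("Hint: set `merge_method: legacy` to use the old full-model merge.")
        _do_merge_lora_legacy(cfg=cfg)
```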
maybe rename standard to legacy or something else since it's no longer axolotl's "standard"
One thing to consider is that this only works for standard LoRA, and not other advanced methods like DoRA
Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
curious if you have any numbers on how much peak VRAM is saved?
Actionable comments posted: 0
♻️ Duplicate comments (3)
src/axolotl/utils/lora_merge_efficient.py (3)
22-43: Fix key parsing and matching: strip(".weight") is incorrect; substring matching risks false positives

- strip(".weight") removes any of the characters in ".weight" from both ends, corrupting keys.
- Matching with `if clean_key in lora_key` can attach the wrong adapters when names share substrings.

Apply an exact-suffix removal and precise endswith-based matching, and bail early once both A/B are found:

```diff
@@
-def find_lora_weights(
-    lora_state: Dict[str, torch.Tensor], key: str
-) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+def find_lora_weights(
+    lora_state: Dict[str, torch.Tensor], key: str
+) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
@@
-    clean_key = key.strip(".weight")
+    # Remove only the exact ".weight" suffix if present
+    clean_key = key[:-7] if key.endswith(".weight") else key
     clean_key = re.sub(r"^(base_model\.model\.|language_model\.)", "", clean_key)
@@
-    for lora_key, lora_weight in lora_state.items():
-        if clean_key in lora_key:
-            if "lora_A" in lora_key:
-                lora_a = lora_weight
-            elif "lora_B" in lora_key:
-                lora_b = lora_weight
+    suffixes_a = (f"{clean_key}.lora_A.weight", f"{clean_key}.lora_A.default.weight")
+    suffixes_b = (f"{clean_key}.lora_B.weight", f"{clean_key}.lora_B.default.weight")
+    for lora_key, lora_weight in lora_state.items():
+        if any(lora_key.endswith(sfx) for sfx in suffixes_a):
+            lora_a = lora_weight
+        elif any(lora_key.endswith(sfx) for sfx in suffixes_b):
+            lora_b = lora_weight
+        if lora_a is not None and lora_b is not None:
+            break
```
200-209: Do not rename .bin shards to .safetensors; always save CPU tensors; preserve original format

Renaming `.bin` to `.safetensors` while using `torch.save` produces invalid files and breaks index JSONs. Also, both safetensors and torch.save should receive CPU tensors.

Apply this fix:

```diff
-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            # Preserve .bin format to keep HF index consistency
+            torch.save(merged_tensors, output_shard_path)
```
46-58: Runtime error: `list[Path]()` is not a constructor; leave patterns as HF-standard

`list[Path]()` will raise at runtime. Initialize with a literal/`list()` instead. Patterns look good per transformers conventions (pytorch_model*.bin, model*.safetensors).

```diff
 def get_model_shards(model_path: Path) -> list[Path]:
     """Find all model shards in the given path."""
-    shards = list[Path]()
+    shards: list[Path] = []
@@
-    patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"]
+    patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"]
```
🧹 Nitpick comments (7)
src/axolotl/utils/lora_merge_efficient.py (7)
120-134: Defensive load for older Torch versions (weights_only not available everywhere)

`torch.load(..., weights_only=True)` is not present in all Torch versions and may raise `TypeError`. Wrap with a compatibility fallback:

```diff
-    if lora_file.suffix == ".safetensors":
-        lora_state = safetensors.torch.load_file(lora_file)
-    else:
-        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+    if lora_file.suffix == ".safetensors":
+        lora_state = safetensors.torch.load_file(lora_file)
+    else:
+        try:
+            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+        except TypeError:
+            # Torch < 2.3 compatibility
+            lora_state = torch.load(lora_file, map_location="cpu")
```

If CI uses an older Torch, this prevents a hard failure.
135-139: Optional: avoid transferring LoRA weights to GPU unless necessary

Moving the entire LoRA state to GPU can be costly and unnecessary if shards are read on CPU for save. If GPU memory is tight, keep LoRA on CPU and cast selectively when computing deltas.

```diff
-    if device != "cpu":
+    if device != "cpu":
         LOG.info(f"Moving LoRA weights to {device}")
         for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
             lora_state[key] = value.to(device)
```

Consider gating this by a config flag (e.g., lora_on_cpu for this path defaulting to True) or only moving `lora_a`/`lora_b` on demand inside the merge loop.
182-199: torch.load on .bin shards: keep tensors on CPU if you intend to save CPU tensors

You map to `device`, which could be CUDA, but safetensors expects CPU tensors on save. You address this later; see next comment. For clarity and lower VRAM pressure, consider `map_location="cpu"` here and move per-tensor only for math if you truly need GPU.

```diff
-        state_dict = torch.load(
-            shard_path, map_location=device
-        )  # nosec B614: loading trusted model weights
+        state_dict = torch.load(
+            shard_path, map_location="cpu"
+        )  # nosec B614: loading trusted model weights
```
211-213: Guard CUDA-only cache clearing to avoid errors on non-CUDA devices

Calling `torch.cuda.empty_cache()` when CUDA is unavailable or when device is not CUDA can error on some setups.

```diff
-    if device != "cpu":
-        torch.cuda.empty_cache()
+    if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
+        torch.cuda.empty_cache()
```
111-118: Verify LoraConfig loader API; provide fallback for environments without from_json_file

Some PEFT versions don't expose `LoraConfig.from_json_file`. If that's the case in your CI, parse JSON and construct `LoraConfig` directly, or use `LoraConfig.from_pretrained(lora_adapter_path)`.

```diff
-    lora_config = LoraConfig.from_json_file(config_file)
+    try:
+        lora_config = LoraConfig.from_json_file(config_file)  # type: ignore[attr-defined]
+    except AttributeError:
+        import json
+        with open(config_file) as f:
+            cfg = json.load(f)
+        lora_config = LoraConfig(**cfg)
```
150-178: Performance: avoid O(N_params × N_lora_tensors) scanning by pre-indexing LoRA

Current approach scans all LoRA entries for every model tensor. Pre-index once into a dict of base_key → (A, B) to reduce time on large models.

I can provide a follow-up patch to build an index like:

- Parse lora_state keys, normalize (strip prefixes), map base_key → (A, B).
- In the loop, just `lookup = lora_index.get(clean_key)`.

Want me to draft this refactor?
96-109: Remote path detection is heuristic; ok for now

The slash check is pragmatic. If you later add support for local paths with slashes that don't exist yet, consider a more explicit "is HF repo id" flag or try/except around snapshot_download.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.434Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.434Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.434Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (1)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
159-159: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
🔇 Additional comments (5)
src/axolotl/utils/lora_merge_efficient.py (5)
150-181: Numerical safety: ensure delta and base cast paths are consistent

You correctly compute in fp32 then cast back. Consider also handling mismatched device/dtype gracefully (e.g., bf16 base). Current code already promotes to float32, which is fine. No action required; just noting it's correct.
60-87: Copy rules look sane; confirm index files are preserved

The logic preserves `pytorch_model*.bin.index.json` and `model*.safetensors.index.json`, and skips GGUF and shard payloads. This is what we want for HF repos.

Please double-check that repos with non-standard component subfolders (e.g., text_encoder/, vision_tower/) are handled upstream by selecting the correct base path; this helper only handles top-level files.
140-146: Good: shard discovery + non-model copy before merge

Flow is sound: discover shards → copy non-model artifacts → iterate shards. Nice.
118-118: Nice: logging scale factor aids reproducibility

Good observability touch.
214-215: Good final summary log

Simple, actionable merge summary.
benchmarks coming soon
Actionable comments posted: 1
♻️ Duplicate comments (4)
src/axolotl/utils/lora_merge_efficient.py (3)
27-37: Fix suffix removal: rstrip(".weight") is incorrect and can corrupt keys

rstrip removes any trailing characters present in the set {'.', 'w', 'e', 'i', 'g', 'h', 't'} rather than the exact suffix, which can break LoRA key matching and lead to silent mis-merges or missed merges.

Apply this diff to remove only the exact ".weight" suffix:

```diff
-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key
```
43-55: Initialize list correctly; current `list[Path]()` raises at runtime

`list[Path]()` is a subscripted type hint, not a constructor. This will raise a TypeError at runtime. Use a standard list initialization with an optional type hint:

```diff
 def get_model_shards(model_path: Path) -> list[Path]:
     """Find all model shards in the given path."""
-    shards = list[Path]()
+    shards: list[Path] = []
     patterns = ["model*.safetensors", "pytorch_model*.bin"]
```
197-206: Preserve original shard format; don't emit fake ".safetensors" from torch.save; ensure CPU tensors before saving

Renaming `.bin` to `.safetensors` while still using `torch.save` produces invalid safetensors files and breaks HF index mappings. Also, ensure tensors are on CPU when writing.

Apply this diff:

```diff
-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but base shards are .bin; preserving .bin format to keep index consistent."
+                )
+            torch.save(merged_tensors, output_shard_path)
```

src/axolotl/cli/merge_lora.py (1)
80-91: Pass device explicitly to support CPU-only hosts; avoid defaulting to CUDA

The helper defaults to `"cuda"`. On CPU-only machines this will fail before merging, forcing an unnecessary fallback to legacy. Detect availability and pass an explicit device.

Apply this diff:

```diff
     LOG.info("Using memory-efficient LoRA merging method...")
     output_path = Path(cfg.output_dir) / "merged"
     safe_tensors = getattr(cfg, "save_safetensors", True)
+    # Select device: prefer CUDA if available, otherwise CPU
+    try:
+        import torch
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    except Exception:  # pragma: no cover
+        device = "cpu"
+
     # Perform memory-efficient merge
     merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
-        safe_tensors=safe_tensors,
+        safe_tensors=safe_tensors,
+        device=device,
     )
```

Note: As per the team learning, we're not tying this to `lora_on_cpu` (that flag is only relevant for full-model loading), just honoring hardware availability.
🧹 Nitpick comments (3)
src/axolotl/utils/lora_merge_efficient.py (2)
112-116: Guard against invalid LoRA configs (r == 0) and log context

Defensive check: if `lora_config.r` is 0, division fails and merge scale is undefined.

Apply this small guard:

```diff
-    scale = lora_config.lora_alpha / lora_config.r
+    if not getattr(lora_config, "r", None):
+        raise ValueError(f"Invalid LoRA config: r={getattr(lora_config, 'r', None)}")
+    scale = lora_config.lora_alpha / lora_config.r
```
156-160: Optional micro-optimization: short-circuit once both A and B found

Early-exiting the loop when both weights are found avoids iterating the entire `lora_state` on every tensor key (benefits large adapters).

Example:

```diff
     for lora_key, lora_weight in lora_state.items():
         if lora_key.endswith(f"{clean_key}.lora_A.weight"):
             lora_a = lora_weight
         elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
             lora_b = lora_weight
+        if lora_a is not None and lora_b is not None:
+            break
```

src/axolotl/cli/merge_lora.py (1)
24-36: Broaden fallback and clarify user guidance

Catching only `RuntimeError` may miss common failure modes (e.g., `FileNotFoundError` for missing shards/config). Consider broadening the exception set and keep the clear fallback message.

Example:

```diff
-    else:
-        try:
-            _do_merge_lora_efficient(cfg=cfg)
-        except RuntimeError as e:
+    else:
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except (RuntimeError, FileNotFoundError, OSError, ValueError) as e:
             LOG.error(f"Memory-efficient merge failed: {e}")
             LOG.info("Falling back to legacy merge method...")
             _do_merge_lora_legacy(cfg=cfg)
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/cli/merge_lora.py (3 hunks)
- src/axolotl/utils/lora_merge_efficient.py (1 hunks)
- src/axolotl/utils/schemas/peft.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/utils/schemas/peft.py
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
- src/axolotl/utils/lora_merge_efficient.py (1): merge_lora_sharded_efficient (86-211)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (1)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
156-156: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)
151-178: Avoid VRAM growth: store CPU tensors in `merged_tensors` during safetensors path

Currently, tensors are read on `device` (often CUDA) and kept on GPU inside `merged_tensors` until the shard is saved, causing peak VRAM to scale with shard size. Convert to CPU before adding to `merged_tensors` to keep GPU usage bounded to a single tensor plus LoRA weights.

Apply this minimal-change diff:

```diff
-                    merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                    merged_tensors[key] = merged_tensor
+                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).detach().cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()
```

Note: Keeping results on CPU preserves memory-efficiency without changing compute placement. If desired, we can also open safetensors with `device="cpu"` and temporarily move tensors to `device` only for the addition; happy to provide that variant.
156-156: Ignore Ruff SIM118 hint here

The linter hint about using membership on dict instead of `.keys()` is not applicable; `f` is a `safetensors.safe_open` handle, not a dict. Iterating `f.keys()` is the intended API.
@ved1beta could you also ensure the weights/logits produced by a model which was merged using the legacy vs. memory efficient method are identical?
This should be ensured with the test run?
Actionable comments posted: 5
♻️ Duplicate comments (4)
src/axolotl/utils/lora_merge_efficient.py (4)
115-115: Nice: Helpful logging added

Scale factor, adapter path, and shard count logs improve debuggability and UX.
Also applies to: 125-126, 141-143
27-37: Bug: rstrip(".weight") is incorrect and corrupts keys

`rstrip(".weight")` removes any trailing characters in the set {'.', 'w', 'e', 'i', 'g', 'h', 't'}, not the exact suffix. This will cause mismatched/missed LoRA key lookups.

Apply this minimal fix:

```diff
-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key
```
43-55: Runtime error: `list[Path]()` is not a constructor

This raises at runtime. Initialize a list instead.

```diff
-    shards = list[Path]()
+    shards: list[Path] = []
```
197-206: Critical: do not rename .bin shards to .safetensors; ensure CPU tensors before saving

Renaming `.bin` to `.safetensors` while calling `torch.save` produces invalid safetensors and breaks HF index files. Preserve the original shard extension and only use `safetensors.torch.save_file` for `.safetensors` inputs. Always save CPU tensors.

```diff
-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            # Preserve .bin format to keep HF indices valid
+            torch.save(merged_tensors, output_shard_path)
```
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)
21-41: Avoid O(N×M) scans of lora_state for every tensor

Scanning the entire `lora_state` per tensor is quadratic and slow on large models. Pre-index LoRA A/B weights once, then do O(1) lookups.

Example approach (new helper and usage):

```python
# New helper (place near find_lora_weights)
def build_lora_index(lora_state: Dict[str, torch.Tensor]) -> dict[str, tuple[torch.Tensor | None, torch.Tensor | None]]:
    index: dict[str, tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] = {}
    for k, v in lora_state.items():
        if k.endswith(".lora_A.weight"):
            base = k[:-len(".lora_A.weight")]
            a, b = index.get(base, (None, None))
            index[base] = (v, b)
        elif k.endswith(".lora_B.weight"):
            base = k[:-len(".lora_B.weight")]
            a, b = index.get(base, (None, None))
            index[base] = (a, v)
    return index
```

Then replace `find_lora_weights(lora_state, key)` with lookups like:

```python
base = key[:-7] if key.endswith(".weight") else key
lora_a, lora_b = lora_index.get(base, (None, None))
```
90-92: Default device "cuda" is risky for a memory-efficient path

Defaulting to GPU can surprise users and increase VRAM usage. Consider defaulting to `"cpu"` and allowing callers to opt-in to a GPU device.

```diff
-    device: str = "cuda",
+    device: str = "cpu",
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
- src/axolotl/utils/schemas/peft.py (1): LoraConfig (28-191)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
156-156: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)
47-53: Patterns for HF shards look correct

Using `"model*.safetensors"` and `"pytorch_model*.bin"` aligns with HF conventions and will discover both single-file and sharded checkpoints.
68-84: Validation complete – all index JSON references are valid
- Ran the provided validation script against both `model.safetensors.index.json` and `pytorch_model.bin.index.json` in `merged_out`; no missing shard files were reported.
- The copy logic in `lora_merge_efficient.py` (lines 68–84) correctly skips model shards and `.gguf` files while preserving all other artifacts (e.g., tokenizer and config files).

With index consistency confirmed, no further changes are needed here.
```python
lora_config = LoraConfig.from_json_file(config_file)
scale = lora_config["lora_alpha"] / lora_config["r"]

LOG.info(f"LoRA scale factor: {scale}")
```
Bug: LoraConfig fields accessed like a dict
LoraConfig.from_json_file returns an object; lora_config["lora_alpha"] will fail. Use attributes. Also guard against zero r.
```diff
-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
-    LOG.info(f"LoRA scale factor: {scale}")
+    lora_config = LoraConfig.from_json_file(config_file)
+    # Ensure 'r' is present and non-zero to avoid division by zero
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)
+    LOG.info(f"LoRA scale factor: {scale}")
```
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 112 to 115, the code
treats LoraConfig as a dict (lora_config["lora_alpha"]) which will raise since
LoraConfig.from_json_file returns an object; change to use attribute access
(e.g., lora_config.lora_alpha and lora_config.r) and add a guard to prevent
division by zero (check if lora_config.r is truthy/greater than zero before
dividing, otherwise set scale to a safe default or raise a clear error), then
log the computed scale.
```python
delta = scale * (
    lora_b.to(torch.float32) @ lora_a.to(torch.float32)
)

merged_tensor = (tensor_fp32 + delta).to(original_dtype)
merged_tensors[key] = merged_tensor
```
💡 Verification agent
🧩 Analysis chain
Support fan_in_fan_out orientation from config
If lora_config.fan_in_fan_out (or lora_fan_in_fan_out) is enabled, a transpose is required to match PEFT’s application. Current code ignores this and may produce incorrect merges for such adapters.
Minimal patch sketch inside both safetensors and .bin branches:
```diff
-    delta = scale * (b_fp32 @ a_fp32)
+    delta = scale * (b_fp32 @ a_fp32)
+    if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+        delta = delta.T
```

Please confirm the exact field name used by your PEFT version and add tests for a fan-in-fan-out adapter.
Also applies to: 188-194
🏁 Script executed:
#!/bin/bash
set -e
# Print context around the LoRA merge logic (lines ~150–220)
sed -n '150,220p' src/axolotl/utils/lora_merge_efficient.py
# Locate any definitions or imports of LoraConfig
rg -n "class.*LoraConfig" -g src
rg -n "import.*LoraConfig" -g src
# Check for existing handling of fan_in_fan_out flags
rg -n "fan_in_fan_out" -g src
rg -n "lora_fan_in_fan_out" -g src || trueLength of output: 2842
🏁 Script executed:
#!/bin/bash
set -e
# Print file header and initial context to locate function signature and imports
sed -n '1,200p' src/axolotl/utils/lora_merge_efficient.py
# Search for any mention of fan_in_fan_out in the entire repository
rg -n "fan_in_fan_out" .
# Search for LoraConfig class or imports across the codebase
rg -n "LoraConfig" .
# Search for usage of lora_config within this file
rg -n "lora_config" -n src/axolotl/utils/lora_merge_efficient.pyLength of output: 8450
Support LoRA fan_in_fan_out orientation in merge
The current merge logic in src/axolotl/utils/lora_merge_efficient.py always computes `delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))` without accounting for the fan_in_fan_out flag in the PEFT config, which will result in incorrect merges when adapters were trained with fan_in_fan_out=True.
Please apply the following change in both the .safetensors branch (around lines 170–175) and the .bin branch (around lines 188–194):
```diff
-    delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+    delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+    if lora_config.fan_in_fan_out:
+        delta = delta.T
```

• Locations to update:
- safetensors loop (after line 170)
- torch.load loop (after line 188)
• Add a unit test with a LoRA adapter configured as fan_in_fan_out=True to verify the transpose is applied correctly.
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 170–175 (safetensors
branch) and around lines 188–194 (torch.load/.bin branch), the merge always
computes delta as scale * (lora_b @ lora_a) and ignores the PEFT config flag
fan_in_fan_out; update both locations to check the adapter config and, when
fan_in_fan_out is True, transpose lora_a and lora_b appropriately (e.g., swap or
transpose operands so multiplication reflects the trained orientation) before
computing delta, then cast back to original dtype as now; also add a unit test
that loads/creates a LoRA adapter with fan_in_fan_out=True, runs the merge, and
asserts the merged tensor matches the expected result when the transpose branch
is applied.
Actionable comments posted: 1
♻️ Duplicate comments (8)
src/axolotl/utils/lora_merge_efficient.py (7)
27-40: Fix incorrect suffix removal and enforce exact LoRA key matching
rstrip(".weight")removes any trailing characters in the set {".weight"} and can corrupt keys; use an exact suffix check. Also keep precise endswith patterns for A/B weights.- clean_key = key.rstrip(".weight") + # Remove only the exact ".weight" suffix + clean_key = key[:-7] if key.endswith(".weight") else key - for lora_key, lora_weight in lora_state.items(): - if lora_key.endswith(f"{clean_key}.lora_A.weight"): - lora_a = lora_weight - elif lora_key.endswith(f"{clean_key}.lora_B.weight"): - lora_b = lora_weight + for lora_key, lora_weight in lora_state.items(): + if lora_key.endswith(f"{clean_key}.lora_A.weight"): + lora_a = lora_weight + elif lora_key.endswith(f"{clean_key}.lora_B.weight"): + lora_b = lora_weightAlso applies to: 32-37
43-55: Initialize shards list correctly; fix runtime error
`list[Path]()` is a type subscription, not a constructor; it will throw at runtime.

```diff
-    shards = list[Path]()
+    shards: list[Path] = []
```
112-116: Access `LoraConfig` fields via attributes and guard division by zero

`LoraConfig.from_json_file` returns an object; indexing like a dict will fail. Also validate `r > 0`.

```diff
-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)
```
132-136: Avoid VRAM spikes: don't bulk-move all LoRA tensors to GPU

Moving the entire LoRA state to GPU defeats the "memory-efficient" goal and can OOM on small GPUs. Keep LoRA tensors on CPU and move per-tensor during merge.

```diff
-    if device != "cpu":
-        LOG.debug(f"Moving LoRA weights to {device}")
-        for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
-            lora_state[key] = value.to(device)
+    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")
```
151-176: Load safetensors on CPU; per-tensor compute on device; support fan_in_fan_out; store results on CPU
- `safe_open(..., device=device)` may load tensors directly to GPU; use CPU and JIT-move for compute.
- Honor `fan_in_fan_out` orientation when present in the config.
- Ensure merged tensors are on CPU before serialization.

```diff
-    if shard_path.suffix == ".safetensors":
-        with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
+    if shard_path.suffix == ".safetensors":
+        # Always open on CPU to minimize VRAM; move per-tensor as needed
+        with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
             if hasattr(f, "metadata") and f.metadata():
                 metadata = f.metadata()
@@
-            for key in f.keys():
+            for key in f.keys():
                 total_tensors += 1
-                tensor = f.get_tensor(key)
+                tensor = f.get_tensor(key)  # CPU tensor
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                if lora_a is not None and lora_b is not None:
+                if lora_a is not None and lora_b is not None:
                     merged_count += 1
@@
-                    original_dtype = tensor.dtype
-                    tensor_fp32 = tensor.to(torch.float32)
-
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-
-                    merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                    merged_tensors[key] = merged_tensor
+                    original_dtype = tensor.dtype
+                    base_fp32 = tensor.to(device).to(torch.float32)
+                    a_fp32 = lora_a.to(device).to(torch.float32)
+                    b_fp32 = lora_b.to(device).to(torch.float32)
+                    delta = scale * (b_fp32 @ a_fp32)
+                    if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                        delta = delta.T
+                    merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()
```
179-196: Load .bin shards on CPU; compute on device; store results on CPU

Load state dict on CPU with `weights_only=True` and JIT-move to device for compute to avoid unnecessary VRAM usage.

```diff
-    else:
-        state_dict = torch.load(
-            shard_path, map_location=device
-        )  # nosec B614: loading trusted model weights
+    else:
+        state_dict = torch.load(  # nosec B614: loading trusted model weights
+            shard_path, map_location="cpu", weights_only=True
+        )
         for key, tensor in state_dict.items():
             total_tensors += 1
             lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-            if lora_a is not None and lora_b is not None:
+            if lora_a is not None and lora_b is not None:
                 merged_count += 1
                 original_dtype = tensor.dtype
-                tensor_fp32 = tensor.to(torch.float32)
-                delta = scale * (
-                    lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                )
-                merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                base_fp32 = tensor.to(device).to(torch.float32)
+                a_fp32 = lora_a.to(device).to(torch.float32)
+                b_fp32 = lora_b.to(device).to(torch.float32)
+                delta = scale * (b_fp32 @ a_fp32)
+                if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                    delta = delta.T
+                merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
             else:
-                merged_tensors[key] = tensor
+                merged_tensors[key] = tensor.detach().cpu()
```
197-206: Do not rename .bin shards to .safetensors; ensure CPU tensors before writing

Renaming `.bin` to `.safetensors` while calling `torch.save` produces invalid safetensors and breaks HF index files. Preserve the original shard format; always write CPU tensors; attach metadata only for safetensors.

```diff
-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors for serialization
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches."
+                )
+            torch.save(merged_tensors, output_shard_path)
```

src/axolotl/cli/merge_lora.py (1)
80-91: Pass device explicitly to support CPU-only hosts

The efficient helper defaults to `"cuda"`. On CPU-only machines this will raise. Detect CUDA availability and pass a device argument.

```diff
-    # Perform memory-efficient merge
-    merge_lora_sharded_efficient(
+    # Choose device: prefer CUDA if available, otherwise CPU
+    try:
+        import torch
+        has_cuda = torch.cuda.is_available()
+    except Exception:
+        has_cuda = False
+    device = "cuda" if has_cuda else "cpu"
+
+    # Perform memory-efficient merge
+    merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
         safe_tensors=safe_tensors,
+        device=device,
     )
```
🧹 Nitpick comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)
68-84: Optional: copy full directory tree (minus model shards) to preserve ancillary assets

Current implementation only copies top-level files. Some models include assets in subfolders (e.g., processors). Consider a recursive copy that excludes shard files by name.
Happy to propose a filtered copy routine using shutil.copytree with an ignore function that excludes shard files and GGUF.
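One way such a filtered copy could look, as a sketch only; the function name is made up and the exclude patterns simply mirror the shard patterns discussed above:

```python
import fnmatch
import shutil
from pathlib import Path


def copy_tree_without_shards(src: Path, dst: Path) -> None:
    """Recursively copy a model directory, skipping shard payloads and GGUF files (sketch)."""
    exclude = ("model*.safetensors", "pytorch_model*.bin", "*.gguf")

    def _ignore(_dir: str, names: list[str]) -> set[str]:
        # Return the subset of entries in this directory that should NOT be copied.
        return {n for n in names if any(fnmatch.fnmatch(n, pat) for pat in exclude)}

    shutil.copytree(src, dst, ignore=_ignore, dirs_exist_ok=True)
```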
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/cli/merge_lora.py (3 hunks)
- src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
- src/axolotl/utils/lora_merge_efficient.py (1): merge_lora_sharded_efficient (86-211)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (1)
- src/axolotl/utils/logging.py (1): get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
156-156: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
🔇 Additional comments (4)
src/axolotl/utils/lora_merge_efficient.py (2)
141-146: Progress counters are helpful; consider logging merge ratio at end (already present)

The final summary log looks good and aligns with expectations for shard-wise merging.
211-211: Nice finishing log

Clear, actionable summary for users. LGTM.
src/axolotl/cli/merge_lora.py (2)
38-71: Legacy path looks consistent. Loading, merging, the dtype cast guard, and saves are in line with the existing behavior. Minor logs are helpful. LGTM.
123-128: Helpful validation for lora_model_dir. Clear error message; good UX.
Were you able to train a LoRA, and then merge using both the legacy and memory-efficient methods to verify identical merged weights from both methods?
Yes, I tried merging and everything as you mentioned earlier. Here is the training output. I have a Claude-generated script for testing identical model weights, and it passes for the checkpoint produced by the training run.
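For reference, a minimal sketch of what such a parity check can look like (assuming both merged models were written as sharded safetensors; the directory names below are placeholders, not the actual script):

# Minimal parity-check sketch: compare every tensor across two merged outputs.
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_all_tensors(model_dir: str) -> dict[str, torch.Tensor]:
    # Combine all *.safetensors shards in the directory into one state dict
    tensors: dict[str, torch.Tensor] = {}
    for shard in sorted(Path(model_dir).glob("*.safetensors")):
        tensors.update(load_file(shard))
    return tensors


legacy = load_all_tensors("legacy_merged")         # placeholder path
efficient = load_all_tensors("memory_eff_merged")  # placeholder path

assert legacy.keys() == efficient.keys(), "tensor key sets differ"
for name, ref in legacy.items():
    other = efficient[name]
    assert ref.dtype == other.dtype and ref.shape == other.shape, f"meta mismatch: {name}"
    assert torch.equal(ref, other), f"values differ: {name}"
print(f"All {len(legacy)} tensors match exactly.")

torch.equal is intentionally strict here; if the two paths are only expected to agree up to floating-point rounding, torch.allclose with tight tolerances is the softer check.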
Actionable comments posted: 1
♻️ Duplicate comments (6)
src/axolotl/utils/lora_merge_efficient.py (5)
27-36: Fix suffix removal: rstrip is incorrect; use exact suffix logic
rstrip(".weight") removes any of those characters from the right end and can corrupt keys, breaking LoRA key matching (see the short illustration after the diff below). Apply:
- clean_key = key.rstrip(".weight") + # Remove only the exact ".weight" suffix + clean_key = key[:-7] if key.endswith(".weight") else key
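A quick illustration of the failure mode with a hypothetical key name (the module path is made up):

# rstrip(".weight") strips characters from the set {., w, e, i, g, h, t}, not the suffix.
key = "model.layers.0.mlp.gate.weight"   # hypothetical key
print(key.rstrip(".weight"))             # -> "model.layers.0.mlp.ga"   (corrupted)
print(key.removesuffix(".weight"))       # -> "model.layers.0.mlp.gate" (Python 3.9+)
print(key[:-7] if key.endswith(".weight") else key)  # same result, works on older Pythons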
43-49: Initialize list correctly; current code raises at runtime
list[Path]() is not a constructor call. Apply:
- shards = list[Path]() + shards: list[Path] = []
112-116: Access LoraConfig attributes (not dict); guard against divide-by-zero
LoraConfig.from_json_file returns an object; dict-style access will fail. Also handle r == 0/None.
- lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config["lora_alpha"] / lora_config["r"] + lora_config = LoraConfig.from_json_file(config_file) + if not getattr(lora_config, "r", None): + raise ValueError("LoRA config 'r' must be > 0") + scale = float(lora_config.lora_alpha) / float(lora_config.r)
198-206: Do not write PyTorch pickles with a .safetensors extension; preserve the original format. Current logic renames
.bin shards to .safetensors but still uses torch.save, corrupting outputs and breaking index files. Apply:
- output_shard_path = output_path / shard_path.name - if safe_tensors and shard_path.suffix == ".safetensors": - safetensors.torch.save_file( - merged_tensors, output_shard_path, metadata=metadata - ) - else: - if safe_tensors: - output_shard_path = output_shard_path.with_suffix(".safetensors") - torch.save(merged_tensors, output_shard_path) + output_shard_path = output_path / shard_path.name + # Ensure CPU tensors before writing + merged_tensors = { + k: (v.detach().cpu() if isinstance(v, torch.Tensor) else v) + for k, v in merged_tensors.items() + } + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata) + else: + if safe_tensors: + LOG.warning( + "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches." + ) + torch.save(merged_tensors, output_shard_path)
176-179: Safer deserialization: use weights_only=True and keep tensors on CPU. Avoids unnecessary pickle deserialization semantics.
Apply:
- state_dict = torch.load( - shard_path, map_location="cpu" - ) # nosec B614: loading trusted model weights + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + )src/axolotl/cli/merge_lora.py (1)
24-36: Normalize merge_method; accept standard; broaden the fallback and log the exception. Improves UX (the PR text mentions "standard") and ensures the fallback triggers on any failure. Also addresses TRY400.
Apply:
- merge_method = getattr(cfg, "merge_method", "memory_efficient") - LOG.info(f"Using {merge_method} LoRA merge method") - - if merge_method == "legacy": - _do_merge_lora_legacy(cfg=cfg) - else: - try: - _do_merge_lora_efficient(cfg=cfg) - except RuntimeError as e: - LOG.error(f"Memory-efficient merge failed: {e}") - LOG.info("Falling back to legacy merge method...") - _do_merge_lora_legacy(cfg=cfg) + merge_method = str(getattr(cfg, "merge_method", "memory_efficient")).lower().replace("-", "_") + if merge_method in {"legacy", "standard"}: + LOG.info("Using legacy LoRA merge method...") + _do_merge_lora_legacy(cfg=cfg) + else: + LOG.info("Using memory-efficient LoRA merge method...") + try: + _do_merge_lora_efficient(cfg=cfg) + except Exception: + LOG.exception("Memory-efficient merge failed") + LOG.info("Falling back to legacy merge method...") + _do_merge_lora_legacy(cfg=cfg)
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (1)
86-92: The device parameter is unused; either remove it or use it per-tensor. Currently flagged by Ruff (ARG001). Since merges run on CPU for memory efficiency, the simplest fix is to drop it.
Apply (and update call sites):
-def merge_lora_sharded_efficient( +def merge_lora_sharded_efficient( base_model_path: Union[str, Path], lora_adapter_path: Union[str, Path], output_path: Union[str, Path], - device: str = "cpu", safe_tensors: bool = True, ) -> None:src/axolotl/cli/merge_lora.py (1)
86-92: If dropping device in the helper, remove it here too. Keeps signatures consistent.
- merge_lora_sharded_efficient( + merge_lora_sharded_efficient( base_model_path=cfg.base_model, lora_adapter_path=cfg.lora_model_dir, output_path=output_path, - device="cpu", safe_tensors=safe_tensors, )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
src/axolotl/cli/merge_lora.py (3 hunks)
src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/cli/merge_lora.py
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
merge_lora_sharded_efficient(86-213)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
LoraConfig(28-191)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py
33-33: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
src/axolotl/utils/lora_merge_efficient.py
90-90: Unused function argument: device
(ARG001)
110-110: Avoid specifying long messages outside the exception class
(TRY003)
121-123: Avoid specifying long messages outside the exception class
(TRY003)
136-136: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py
[error] ruff-format reformatted 1 file during pre-commit; the hook failed in CI (exit code 1). Run 'pre-commit run --all-files' locally to apply formatting changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (1)
85-96: Install dependencies and verify merge outputs match
Install torch and safetensors, then set LEGACY_DIR/EFFICIENT_DIR and run the provided equivalence-check script to confirm all tensor keys and values match exactly.
src/axolotl/cli/merge_lora.py (1)
17-23: Add a --validate-merge flag to compare both merge paths
Implement an optional --validate-merge (or --self-check) switch in do_merge_lora/do_cli that, when enabled, runs both the legacy and memory_efficient merges (into temp dirs or in-memory), loads all output tensors via safetensors.torch and torch.load, verifies the key sets match, and uses torch.allclose(atol=0, rtol=0) on each tensor, erroring or warning on any mismatch. This ensures users can confirm both implementations produce identical merged weights.
Actionable comments posted: 0
♻️ Duplicate comments (6)
src/axolotl/utils/lora_merge_efficient.py (6)
27-36: Bug: rstrip(".weight") corrupts keys; remove the exact suffix instead. rstrip removes any trailing chars from the set, not the exact substring. This can mis-match LoRA keys.
Apply:
- clean_key = key.rstrip(".weight") + # Remove only the exact ".weight" suffix + clean_key = key[:-7] if key.endswith(".weight") else key
43-55: Bug: list[Path]() is not a constructor call. This raises at runtime. Initialize normally.
- shards = list[Path]() + shards: list[Path] = []
197-205: Critical: don't rename .bin→.safetensors or write pickles with a .safetensors extension. This corrupts outputs and breaks HF index references. Preserve the original shard format; ensure CPU tensors before writing.
- output_shard_path = output_path / shard_path.name - if safe_tensors and shard_path.suffix == ".safetensors": - safetensors.torch.save_file( - merged_tensors, output_shard_path, metadata=metadata - ) - else: - if safe_tensors: - output_shard_path = output_shard_path.with_suffix(".safetensors") - torch.save(merged_tensors, output_shard_path) + output_shard_path = output_path / shard_path.name + # Ensure CPU tensors before writing + merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata) + else: + if safe_tensors: + LOG.warning( + "safe_tensors=True requested but input shards are .bin; preserving .bin format " + "to avoid index mismatches. Consider a separate convert step." + ) + torch.save(merged_tensors, output_shard_path)
182-193: Correctness: handle fan_in_fan_out in the .bin branch too. Mirror the transpose logic here.
- delta = scale * (lora_b_fp32 @ lora_a_fp32) + delta = scale * (lora_b_fp32 @ lora_a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T
111-114: Bug: LoraConfig used like a dict; add a zero-division guard. from_json_file returns an object. Also guard r > 0.
- lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config["lora_alpha"] / lora_config["r"] + lora_config = LoraConfig.from_json_file(config_file) + if not getattr(lora_config, "r", None) or lora_config.r <= 0: + raise ValueError("LoRA config 'r' must be > 0") + scale = float(lora_config.lora_alpha) / float(lora_config.r)
147-174: Correctness: handle fan_in_fan_out orientation when merging a safetensors shard. Adapters trained with fan_in_fan_out=True require a transpose.
- delta = scale * (lora_b_fp32 @ lora_a_fp32) + delta = scale * (lora_b_fp32 @ lora_a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T
🧹 Nitpick comments (4)
src/axolotl/utils/lora_merge_efficient.py (4)
21-41: Speed up lookups by pre-indexing LoRA A/B once. The current O(N_base * N_lora) scan per tensor is avoidable. Build a suffix→(A,B) index once.
Option sketch (new helper outside diff for context):
def build_lora_index(lora_state: dict[str, torch.Tensor]) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    a_map, b_map = {}, {}
    for k, v in lora_state.items():
        if k.endswith(".lora_A.weight"):
            a_map[k[: -len(".lora_A.weight")]] = v
        elif k.endswith(".lora_B.weight"):
            b_map[k[: -len(".lora_B.weight")]] = v
    return {k: (a_map[k], b_map[k]) for k in a_map.keys() & b_map.keys()}
Then in the merge loop, resolve with lora_index.get(clean_key).
82-84: Preserve file metadata when copying. Use copy2 to retain mtime/permissions.
- shutil.copy(filepath, output_path) + shutil.copy2(filepath, output_path)
175-181: Safety/efficiency: load .bin shards weights-only on CPU. Prefer weights_only=True to avoid executing pickles; we only need tensors.
- state_dict = torch.load( - shard_path, map_location="cpu" - ) # nosec B614: loading trusted model weights + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + )
143-174: Optional: keep tensors on CPU but detach/cpu all merged values explicitly. Be explicit to avoid accidental device retention if future changes introduce GPU ops.
- merged_tensors[key] = merged_tensor + merged_tensors[key] = merged_tensor.detach().cpu() @@ - merged_tensors[key] = merged_tensor + merged_tensors[key] = merged_tensor.detach().cpu()Also applies to: 175-196
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/axolotl/utils/lora_merge_efficient.py(1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
LoraConfig(28-191)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
109-109: Avoid specifying long messages outside the exception class
(TRY003)
120-122: Avoid specifying long messages outside the exception class
(TRY003)
135-135: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py
[error] 172-172: ruff-format: File reformatted by pre-commit; please re-run pre-commit or commit formatting changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/utils/lora_merge_efficient.py (3)
47-54: Shard patterns look good. Using model*.safetensors and pytorch_model*.bin matches HF conventions (per our prior learning).
140-174: Manual parity and VRAM profiling required
The automated script couldn't run in this environment (missing torch), so please verify in your setup:
- Compare all tensors in legacy_merged/ vs. memory_eff_merged/ byte-for-byte (using strict allclose for floats and equal for ints).
- Profile peak GPU memory during both merge routines (e.g. nvidia-smi --query-gpu=memory.used --loop-ms=500) and report the savings; a small probe sketch follows this item.
Also apply the same checks to the code blocks at lines 175–196 and 197–205.
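Besides sampling nvidia-smi, peak usage can also be captured from inside Python when the merge runs on CUDA; a rough sketch (paths are placeholders, and the numbers only mean something if the merge actually touches the GPU):

# Rough peak-VRAM probe around a merge call (sketch; adjust paths/arguments).
import torch

from axolotl.utils.lora_merge_efficient import merge_lora_sharded_efficient

torch.cuda.reset_peak_memory_stats()
merge_lora_sharded_efficient(
    base_model_path="base_model",      # placeholder
    lora_adapter_path="lora_adapter",  # placeholder
    output_path="merged_out",          # placeholder
)
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak CUDA memory during merge: {peak_gib:.2f} GiB")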
1-213: Install and run pre-commit hooks to apply ruff-format fixes
CI is failing on ruff-format; ensure you have pre-commit installed and run:
pre-commit install
pre-commit run --all-files
to reflow and commit the formatting changes.
Actionable comments posted: 3
♻️ Duplicate comments (8)
src/axolotl/utils/lora_merge_efficient.py (8)
27-27: Do not use rstrip(".weight") — it corrupts keys
rstrip(".weight") removes any trailing chars from the set {'.', 'w', 'e', 'i', 'g', 'h', 't'}, not the exact suffix. Use explicit suffix removal.
43-55: Runtime error:list[Path]()is not a constructorThis will raise at runtime. Initialize as a standard list.
- shards = list[Path]() + shards: list[Path] = []
112-115: LoraConfig is an object, not a dict; guard against zero r
LoraConfig.from_json_file returns an object; lora_config["..."] will fail. Also prevent division by zero.
132-136: Keep the LoRA state on CPU; avoid a bulk device transfer. Moving the entire adapter to GPU defeats the memory-efficient design and can spike VRAM.
- if device != "cpu": - LOG.info(f"Moving LoRA weights to {device}") - for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): - lora_state[key] = value.to(device) + LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")
151-159: Open safetensors shards on CPU to control VRAM; move per-tensor for compute. Loading tensors directly on GPU risks VRAM blowups and makes serialization harder.
- if shard_path.suffix == ".safetensors": - with safetensors.safe_open(shard_path, framework="pt", device=device) as f: + if shard_path.suffix == ".safetensors": + # Always open on CPU; move specific tensors to `device` only for compute + with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
179-200: Load .bin shards on CPU; compute per-tensor on device; store CPU tensors; prefer weights_only=True. This keeps memory bounded, improves safety, and aligns with the safetensors branch.
- state_dict = torch.load( - shard_path, map_location=device - ) # nosec B614: loading trusted model weights + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + ) @@ - original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - lora_a_fp32 = lora_a.to(torch.float32) - lora_b_fp32 = lora_b.to(torch.float32) - - delta = scale * (lora_b_fp32 @ lora_a_fp32) - merged_tensor = (tensor_fp32 + delta).to(original_dtype) - merged_tensors[key] = merged_tensor - - del tensor_fp32, lora_a_fp32, lora_b_fp32, delta + original_dtype = tensor.dtype + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T + merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu() + del base_fp32, a_fp32, b_fp32, delta else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu()
201-209: Do not mislabel file formats; preserve the original shard extension and write CPU tensors. Currently, when the input is
.bin and safe_tensors=True, you rename to .safetensors but still use torch.save, producing invalid safetensors and breaking HF indices. Also ensure all tensors are on CPU before saving.
167-175: Compute on device, respect fan_in_fan_out, and store CPU tensors for save
- Do the compute on device in FP32, but store merged results on CPU.
- Handle PEFT's fan_in_fan_out=True by transposing delta.
- Avoid keeping GPU tensors in merged_tensors (safetensors requires CPU tensors).
🧹 Nitpick comments (5)
src/axolotl/utils/lora_merge_efficient.py (5)
32-38: Stop scanning once both LoRA tensors are found. Micro-optimization: break early to avoid an O(N) scan over all adapter tensors per key.
- for lora_key, lora_weight in lora_state.items(): + for lora_key, lora_weight in lora_state.items(): if lora_key.endswith(f"{clean_key}.lora_A.weight"): lora_a = lora_weight elif lora_key.endswith(f"{clean_key}.lora_B.weight"): lora_b = lora_weight + if lora_a is not None and lora_b is not None: + break
211-214: Guard CUDA cache calls. Avoid calling CUDA APIs when CUDA isn't available.
- if device != "cpu": - torch.cuda.empty_cache() + if device != "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache()
21-41: Performance: avoid O(N×M) adapter scans by pre-indexing LoRA keys. The current per-tensor scan over the entire lora_state is costly on large models. Build a suffix-index map once. Happy to provide a follow-up patch that constructs:
- a map from cleaned base key suffix → (A, B)
- or a trie/suffix map keyed by last two path segments
Let me know if you want the diff.
Also applies to: 137-146
112-115: Correctness: ensure LoRA orientation (fan_in_fan_out) is supported. Adapters trained with fan_in_fan_out=True require transposition during merge; the proposed diffs add this. Please add a unit test exercising this flag (see the sketch after this item).
137-143: Log improvements: shard count and adapter pathAdd low-noise debug logs to aid support; aligns with reviewer suggestions.
- LOG.debug(f"Found {len(model_shards)} model shards") + LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/axolotl/utils/lora_merge_efficient.py(1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
LoraConfig(28-191)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
110-110: Avoid specifying long messages outside the exception class
(TRY003)
121-123: Avoid specifying long messages outside the exception class
(TRY003)
139-139: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py
[error] 176-179: ruff-format: 1 file reformatted by the hook; please re-run 'pre-commit run --all-files' and commit the changes.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)
139-139: Run pre-commit locally and centralize exception messages. Pre-commit isn't available here; please run
pre-commit run --all-files locally to apply any required formatting fixes. Also shorten or centralize long exception messages (e.g. at lines 139 and 215).
215-215: Add automated parity and VRAM measurement tests
CI failed to run the parity script due to missing torch—ensure your test environment installs PyTorch and add a CI step that:
- Runs merge_lora_sharded_efficient and merge_lora on the same base+adapter
- Asserts all tensors match (rtol=1e-4, atol=1e-5)
- Captures peak GPU memory during each merge to validate the expected savings
def merge_lora_sharded_efficient(
    base_model_path: Union[str, Path],
    lora_adapter_path: Union[str, Path],
    output_path: Union[str, Path],
    device: str = "cuda",
    safe_tensors: bool = True,
) -> None:
    """
    Memory-efficient LoRA merging that processes shards individually
    without loading the full model into memory.
    """
🛠️ Refactor suggestion
Safety: prevent in-place overwrite of source directory
If output_path equals base_model_path, you risk clobbering source files. Add a guard and fail fast.
output_path = Path(output_path)
@@
- os.makedirs(output_path, exist_ok=True)
+ if output_path.resolve() == base_model_path.resolve():
+ raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+ os.makedirs(output_path, exist_ok=True)
Also applies to: 101-106
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 86-96 (and similarly
at 101-106), add a fail-fast guard that prevents output_path from being the same
as base_model_path (and also disallow matching lora_adapter_path) to avoid
in-place overwrites; implement by converting inputs to pathlib.Path and
comparing resolved absolute paths (Path(...).resolve()) and if any matches raise
a clear ValueError (or SystemExit) with a message like "output_path must be
different from base_model_path/lora_adapter_path" before performing any file
operations.
Actionable comments posted: 1
♻️ Duplicate comments (11)
src/axolotl/utils/lora_merge_efficient.py (9)
11-14: Fix import: NameError on safetensors.safe_open. You call safetensors.safe_open but never import the top-level module.
-import safetensors.torch +import safetensors # needed for safetensors.safe_open +import safetensors.torch
45-49: Initialize the list correctly. list[Path] is a type hint, not a constructor; this will raise.
- shards = list[Path]() + shards: list[Path] = []
72-81: Skip all HF weight files not selected as shards. Avoids copying opposite-format weights into the output (mixed repo).
- if filepath.name in shard_names: - continue - if filepath.suffix == ".gguf": - continue - if filepath.name.startswith("model") and filepath.suffix == ".safetensors": - continue + if filepath.name in shard_names: + continue + # Skip any other HF weight files + if ( + (filepath.suffix in {".safetensors", ".bin"}) + and (filepath.name.startswith("model") or filepath.name.startswith("pytorch_model")) + ): + continue + if filepath.suffix == ".gguf": + continue
101-106: Guard against in-place overwrite of source dirs. Prevent output_path == base_model_path or lora_adapter_path.
- os.makedirs(output_path, exist_ok=True) + if output_path.resolve() in {base_model_path.resolve(), lora_adapter_path.resolve()}: + raise ValueError("output_path must differ from base_model_path and lora_adapter_path") + os.makedirs(output_path, exist_ok=True)
112-115: Access LoraConfig attributes and validate r/alpha. LoraConfig is an object, not a dict; also guard r > 0 and the presence of lora_alpha.
- lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config["lora_alpha"] / lora_config["r"] + lora_config = LoraConfig.from_json_file(config_file) + if not getattr(lora_config, "r", None) or not getattr(lora_config, "lora_alpha", None): + raise ValueError("LoRA config missing required fields: 'r' and 'lora_alpha'") + if int(lora_config.r) == 0: + raise ValueError("LoRA config 'r' must be > 0") + scale = float(lora_config.lora_alpha) / float(lora_config.r)
132-136: Keep LoRA on CPU; avoid a bulk GPU transfer. Bulk moving defeats the "memory-efficient" goal and can OOM.
- if device != "cpu": - LOG.info(f"Moving LoRA weights to {device}") - for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): - lora_state[key] = value.to(device) + LOG.debug("Keeping LoRA weights on CPU; moving per-tensor during merge")
151-176: Open safetensors on CPU; per-tensor device compute; handle fan_in_fan_out; store CPU tensors. Prevents VRAM spikes and ensures correct orientation; outputs must be CPU tensors for serialization.
- if shard_path.suffix == ".safetensors": - with safetensors.safe_open(shard_path, framework="pt", device=device) as f: + if shard_path.suffix == ".safetensors": + with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f: if hasattr(f, "metadata") and f.metadata(): metadata = f.metadata() @@ - tensor = f.get_tensor(key) + tensor = f.get_tensor(key) # CPU tensor lora_a, lora_b = find_lora_weights(lora_state, key) @@ - original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) - ) - - merged_tensor = (tensor_fp32 + delta).to(original_dtype) - merged_tensors[key] = merged_tensor + original_dtype = tensor.dtype + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T + merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu() else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu()
179-196: Load .bin shards on CPU; per-tensor device compute; store CPU tensors; handle fan_in_fan_out. Avoids loading the entire shard onto the GPU; adds robustness with weights_only.
- else: - state_dict = torch.load( - shard_path, map_location=device - ) # nosec B614: loading trusted model weights + else: + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + ) for key, tensor in state_dict.items(): total_tensors += 1 lora_a, lora_b = find_lora_weights(lora_state, key) @@ - if lora_a is not None and lora_b is not None: + if lora_a is not None and lora_b is not None: merged_count += 1 original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) - ) - merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype) + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T + merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu() else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu()
197-206: Do not rename .bin shards to .safetensors; preserve the original format. Renaming while using torch.save corrupts repos and breaks indices. Always save CPU tensors.
- output_shard_path = output_path / shard_path.name - if safe_tensors and shard_path.suffix == ".safetensors": - safetensors.torch.save_file( - merged_tensors, output_shard_path, metadata=metadata - ) - else: - if safe_tensors: - output_shard_path = output_shard_path.with_suffix(".safetensors") - torch.save(merged_tensors, output_shard_path) + output_shard_path = output_path / shard_path.name + # Ensure CPU tensors before writing + merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata) + else: + if safe_tensors: + LOG.warning( + "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches." + ) + torch.save(merged_tensors, output_shard_path)src/axolotl/cli/merge_lora.py (2)
24-36: Normalize merge_method; accept 'standard'; broaden fallback; log the exception. Matches the PR wording and avoids silent failures.
- merge_method = getattr(cfg, "merge_method", "memory_efficient") - LOG.info(f"Using {merge_method} LoRA merge method") - - if merge_method == "legacy": - _do_merge_lora_legacy(cfg=cfg) - else: - try: - _do_merge_lora_efficient(cfg=cfg) - except RuntimeError as e: - LOG.error(f"Memory-efficient merge failed: {e}") - LOG.info("Falling back to legacy merge method...") - _do_merge_lora_legacy(cfg=cfg) + merge_method = str(getattr(cfg, "merge_method", "memory_efficient")).lower().replace("-", "_") + if merge_method in {"legacy", "standard"}: + LOG.info("Using legacy LoRA merge method...") + _do_merge_lora_legacy(cfg=cfg) + else: + LOG.info("Using memory-efficient LoRA merge method...") + try: + _do_merge_lora_efficient(cfg=cfg) + except Exception: + LOG.exception("Memory-efficient merge failed; falling back to legacy merge...") + _do_merge_lora_legacy(cfg=cfg)
80-91: Pass device explicitly; default to CPU when CUDA is unavailable. Prevents crashes on CPU-only hosts and avoids hidden GPU use.
- output_path = Path(cfg.output_dir) / "merged" - safe_tensors = getattr(cfg, "save_safetensors", True) + output_path = Path(cfg.output_dir) / "merged" + safe_tensors = getattr(cfg, "save_safetensors", True) + # Choose device: prefer CUDA if available + try: + import torch # local import + device = "cuda" if torch.cuda.is_available() else "cpu" + except Exception: + device = "cpu" @@ - merge_lora_sharded_efficient( + merge_lora_sharded_efficient( base_model_path=cfg.base_model, lora_adapter_path=cfg.lora_model_dir, output_path=output_path, - safe_tensors=safe_tensors, + safe_tensors=safe_tensors, + device=device, )
🧹 Nitpick comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)
208-209: Guard CUDA cache calls. Avoid calling CUDA APIs on CPU-only hosts.
- if device != "cpu": - torch.cuda.empty_cache() + if device != "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache()
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
src/axolotl/cli/merge_lora.py (3 hunks)
src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/cli/merge_lora.py
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
merge_lora_sharded_efficient(86-211)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
LoraConfig(28-191)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py
33-33: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
src/axolotl/utils/lora_merge_efficient.py
110-110: Avoid specifying long messages outside the exception class
(TRY003)
121-123: Avoid specifying long messages outside the exception class
(TRY003)
139-139: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py
[error] 176-176: ruff-format reformatted this file during pre-commit; the pre-commit hook failed with exit code 1. Run 'pre-commit run --all-files' and commit the updated file to apply formatting.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.8.0)
🔇 Additional comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)
1-212: Install and run pre-commit to fix formatting
pre-commit wasn't found in this environment. Ensure it's installed and run pre-commit run --all-files to apply lint/format fixes before committing.
Actionable comments posted: 0
♻️ Duplicate comments (4)
src/axolotl/cli/merge_lora.py (4)
80-80: Prefer debug level for this trace message (consistency with other suggestions).- LOG.info("Using memory-efficient LoRA merging method...") + LOG.debug("Using memory-efficient LoRA merging method...")
127-127: Fix wording: this path points to adapter weights, not a merged model directory.- f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`" + f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"
24-35: Normalize merge_method, accept 'standard', and guarantee fallback with stacktrace.Without normalization,
merge_method: standard(as advertised in the PR) routes to the efficient path unintentionally. Also, catching only RuntimeError can miss common failures (e.g., CPU-only hosts raising AssertionError when device defaults to "cuda"), preventing the promised fallback. Use LOG.exception to keep the stacktrace (ruff TRY400).Apply:
- merge_method = getattr(cfg, "merge_method", "memory_efficient") - LOG.info(f"Using {merge_method} LoRA merge method") - - if merge_method == "legacy": - _do_merge_lora_legacy(cfg=cfg) - else: - try: - _do_merge_lora_efficient(cfg=cfg) - except RuntimeError as e: - LOG.error(f"Memory-efficient merge failed: {e}") - LOG.info("Falling back to legacy merge method...") - _do_merge_lora_legacy(cfg=cfg) + merge_method = str(getattr(cfg, "merge_method", "memory_efficient") or "memory_efficient").strip().lower().replace("-", "_") + if merge_method in {"legacy", "standard"}: + LOG.debug("Using legacy LoRA merging method...") + _do_merge_lora_legacy(cfg=cfg) + else: + LOG.debug("Using memory-efficient LoRA merging method...") + try: + _do_merge_lora_efficient(cfg=cfg) + except Exception: + LOG.exception("Memory-efficient merge failed; falling back to legacy.") + _do_merge_lora_legacy(cfg=cfg)
82-91: Pass device explicitly to avoid CPU-only crashes and unnecessary fallbacks. The efficient helper defaults to device="cuda". On CPU-only hosts this will throw before merging. Pick the device based on CUDA availability and pass it through.
output_path = Path(cfg.output_dir) / "merged" safe_tensors = getattr(cfg, "save_safetensors", True) # Perform memory-efficient merge + try: + import torch # local import to keep CLI lightweight + device = "cuda" if torch.cuda.is_available() else "cpu" + except Exception: + device = "cpu" merge_lora_sharded_efficient( base_model_path=cfg.base_model, lora_adapter_path=cfg.lora_model_dir, output_path=output_path, - safe_tensors=safe_tensors, + safe_tensors=safe_tensors, + device=device, )
🧹 Nitpick comments (1)
src/axolotl/cli/merge_lora.py (1)
45-46: Unify the safe-serialization default with the efficient path. Use getattr with a default instead of an identity comparison for consistent behavior.
- safe_serialization = cfg.save_safetensors is True + safe_serialization = getattr(cfg, "save_safetensors", True)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/axolotl/cli/merge_lora.py(3 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/cli/merge_lora.py
🧬 Code graph analysis (1)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
merge_lora_sharded_efficient(86-211)
src/axolotl/utils/logging.py (1)
get_logger(42-49)
src/axolotl/utils/dict.py (1)
DictDefault(6-38)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py
33-33: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: preview
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (4)
src/axolotl/cli/merge_lora.py (4)
12-12: Import of efficient merge helper looks good.
19-19: Docstring accurately reflects dual strategies.
38-44: Legacy merge entrypoint/readability is fine.
93-93: Success debug message reads well.
Actionable comments posted: 0
♻️ Duplicate comments (9)
src/axolotl/utils/lora_merge_efficient.py (9)
27-40: Do not use rstrip(".weight"); exact suffix and O(1) lookups.rstrip removes any trailing chars in the set ".weight", corrupting keys and causing mis-merges.
Apply:
- clean_key = key.rstrip(".weight") - - lora_a = None - lora_b = None - - for lora_key, lora_weight in lora_state.items(): - if lora_key.endswith(f"{clean_key}.lora_A.weight"): - lora_a = lora_weight - elif lora_key.endswith(f"{clean_key}.lora_B.weight"): - lora_b = lora_weight - - if lora_a is not None and lora_b is not None: - return lora_a, lora_b - return None, None + clean_key = key[:-7] if key.endswith(".weight") else key + a_key = f"{clean_key}.lora_A.weight" + b_key = f"{clean_key}.lora_B.weight" + lora_a = lora_state.get(a_key) + lora_b = lora_state.get(b_key) + if lora_a is not None and lora_b is not None: + return lora_a, lora_b + return None, None
45-55: list[Path]() is not a constructor call. This will raise at runtime. Initialize normally.
Apply:
- shards = list[Path]() + shards: list[Path] = []
97-106: Guard against overwriting source dirs. Disallow output_path == base_model_path or == lora_adapter_path.
Apply:
output_path = Path(output_path) @@ - os.makedirs(output_path, exist_ok=True) + if output_path.resolve() == base_model_path.resolve(): + raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards") + if output_path.resolve() == lora_adapter_path.resolve(): + raise ValueError("output_path must differ from lora_adapter_path") + os.makedirs(output_path, exist_ok=True)
132-136: Keep LoRA on CPU; avoid a bulk device transfer. Bulk .to(device) defeats the memory-efficient goal and can OOM.
Apply:
- if device != "cpu": - LOG.debug(f"Moving LoRA weights to {device}") - for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"): - lora_state[key] = value.to(device) + # Keep LoRA on CPU; move per-tensor during merge + LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")
195-204: Do not write PyTorch pickles with a .safetensors extension; preserve the input shard format. Renaming .bin to .safetensors while calling torch.save corrupts outputs and breaks HF indices.
Apply:
- output_shard_path = output_path / shard_path.name - if safe_tensors and shard_path.suffix == ".safetensors": - safetensors.torch.save_file( - merged_tensors, output_shard_path, metadata=metadata - ) - else: - if safe_tensors: - output_shard_path = output_shard_path.with_suffix(".safetensors") - torch.save(merged_tensors, output_shard_path) + output_shard_path = output_path / shard_path.name + # Ensure CPU tensors before writing + merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()} + if shard_path.suffix == ".safetensors": + safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata) + else: + if safe_tensors: + LOG.warning( + "safe_tensors=True requested but input shards are .bin; preserving .bin format " + "to avoid index mismatches. Conversion is not implemented here." + ) + torch.save(merged_tensors, output_shard_path)
11-14: Fix import for safetensors.safe_open (NameError). You call safetensors.safe_open but never import the top-level safetensors package.
Apply:
-import safetensors.torch +import safetensors # needed for safetensors.safe_open +import safetensors.torch
112-116: LoraConfig is an object, not a dict; also guard r > 0. Current code will raise; division-by-zero risk.
Apply:
- lora_config = LoraConfig.from_json_file(config_file) - scale = lora_config["lora_alpha"] / lora_config["r"] + lora_config = LoraConfig.from_json_file(config_file) + if not getattr(lora_config, "r", None): + raise ValueError("LoRA config 'r' must be > 0") + scale = float(lora_config.lora_alpha) / float(lora_config.r)
151-177: Open safetensors on CPU; per-tensor JIT moves; support fan_in_fan_out; store results on CPU. Current code opens on device, does FP32 on the host device without ensuring device alignment, and ignores fan_in_fan_out.
Apply:
- if shard_path.suffix == ".safetensors": - with safetensors.safe_open(shard_path, framework="pt", device=device) as f: + if shard_path.suffix == ".safetensors": + with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f: if hasattr(f, "metadata") and f.metadata(): metadata = f.metadata() @@ - for key in f.keys(): + for key in f.keys(): total_tensors += 1 - tensor = f.get_tensor(key) + tensor = f.get_tensor(key) # CPU tensor lora_a, lora_b = find_lora_weights(lora_state, key) @@ - if lora_a is not None and lora_b is not None: + if lora_a is not None and lora_b is not None: merged_count += 1 @@ - original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) - ) - - merged_tensor = (tensor_fp32 + delta).to(original_dtype) - merged_tensors[key] = merged_tensor + original_dtype = tensor.dtype + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T + merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu() else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu()
178-194: Load .bin shards on CPU; weights_only; per-tensor JIT moves; support fan_in_fan_out; store on CPU. Avoid GPU preload; ensure robustness and memory efficiency.
Apply:
- else: - state_dict = torch.load(shard_path, map_location=device) # nosec B614: loading trusted model weights + else: + state_dict = torch.load( # nosec B614: loading trusted model weights + shard_path, map_location="cpu", weights_only=True + ) for key, tensor in state_dict.items(): total_tensors += 1 lora_a, lora_b = find_lora_weights(lora_state, key) @@ - if lora_a is not None and lora_b is not None: + if lora_a is not None and lora_b is not None: merged_count += 1 original_dtype = tensor.dtype - tensor_fp32 = tensor.to(torch.float32) - delta = scale * ( - lora_b.to(torch.float32) @ lora_a.to(torch.float32) - ) - merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype) + base_fp32 = tensor.to(device).to(torch.float32) + a_fp32 = lora_a.to(device).to(torch.float32) + b_fp32 = lora_b.to(device).to(torch.float32) + delta = scale * (b_fp32 @ a_fp32) + if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)): + delta = delta.T + merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu() else: - merged_tensors[key] = tensor + merged_tensors[key] = tensor.detach().cpu()
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)
72-81: Skip opposite-format HF weight files when copying non-model files.
Prevents mixed-format outputs (e.g., copying a stray .bin when merging safetensors).
Apply:
```diff
-        if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
-            continue
+        # Skip any other HF weight files not selected as shards
+        if (
+            (filepath.name.startswith("model") and filepath.suffix == ".safetensors")
+            or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin")
+        ):
+            continue
```
205-209: Free host memory too.
Add gc.collect() after each shard; keep empty_cache as-is.
Apply:
```diff
+    import gc
     del merged_tensors
     if device != "cpu":
         torch.cuda.empty_cache()
+    gc.collect()
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.
Applied to files:
src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py
110-110: Avoid specifying long messages outside the exception class
(TRY003)
121-123: Avoid specifying long messages outside the exception class
(TRY003)
139-139: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)
86-93: Verify parity with legacy merge
Install torch and safetensors in your environment, then run the parity-check script on a small model. Confirm that the maximum absolute difference is 0.0 and there are no mismatched keys.
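For reference, a minimal parity-check sketch along those lines, assuming both merged outputs were saved as safetensors shards; the directory names are illustrative:

```python
# Minimal parity check: compare two merged checkpoints tensor-by-tensor.
# Assumes both output dirs contain *.safetensors shards; paths are illustrative.
from pathlib import Path

import torch
from safetensors import safe_open


def load_all_tensors(model_dir: str) -> dict[str, torch.Tensor]:
    tensors: dict[str, torch.Tensor] = {}
    for shard in sorted(Path(model_dir).glob("*.safetensors")):
        with safe_open(str(shard), framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
    return tensors


legacy = load_all_tensors("merged-legacy")
efficient = load_all_tensors("merged-memory-efficient")

mismatched = sorted(set(legacy) ^ set(efficient))
print(f"mismatched keys: {mismatched}")

max_diff = 0.0
for key in set(legacy) & set(efficient):
    diff = (legacy[key].float() - efficient[key].float()).abs().max().item()
    max_diff = max(max_diff, diff)
print(f"max abs diff: {max_diff}")
```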
memory usage for both of the merges
I saw you added a VRAM / time comparison, can you specify which model you used?
src/axolotl/cli/merge_lora.py
Outdated
```python
    merge_method = (
        str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
    )
```
merge_method can only take the values Literal["legacy", "memory_efficient"], so you don't need this string handling.
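As a sketch, the dispatch could rely on the validated literal directly; the helper names `_do_merge_lora_efficient` / `_do_merge_lora_legacy` are taken from the snippet quoted below, and the exact literal values are whatever the schema ends up allowing:

```python
# Sketch: dispatch directly on the validated config value, no string normalization.
def do_merge_lora(cfg) -> None:
    # cfg.merge_method is already constrained by the pydantic schema,
    # so only the known literals need handling here.
    if cfg.merge_method == "memory_efficient":
        _do_merge_lora_efficient(cfg=cfg)
    else:
        _do_merge_lora_legacy(cfg=cfg)
```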
src/axolotl/cli/merge_lora.py
Outdated
```python
    merge_method = (
        str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
    )
    if merge_method in {"legacy", "standard"}:
```
"standard" doesn't exist
src/axolotl/cli/merge_lora.py
Outdated
```python
    try:
        _do_merge_lora_efficient(cfg=cfg)
    except Exception:  # pylint: disable=broad-exception-caught
        LOG.exception("Memory-efficient merge failed; falling back to legacy.")
        _do_merge_lora_legacy(cfg=cfg)
```
tbh I'd rather have a hard failure here so we know if something is broken
Suggested change:

```diff
-    try:
-        _do_merge_lora_efficient(cfg=cfg)
-    except Exception:  # pylint: disable=broad-exception-caught
-        LOG.exception("Memory-efficient merge failed; falling back to legacy.")
-        _do_merge_lora_legacy(cfg=cfg)
+    _do_merge_lora_efficient(cfg=cfg)
```
If there are unsupported combinations (you mentioned DoRA, RSLoRA), we should validate this in the pydantic model and raise an error there.
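A possible shape for that check, sketched with pydantic v2; the field names peft_use_dora and peft_use_rslora are assumptions standing in for whatever the real config exposes:

```python
# Sketch: reject unsupported adapter options for the memory-efficient merge
# at config-validation time instead of failing mid-merge.
from pydantic import BaseModel, model_validator


class MergeConfig(BaseModel):
    merge_method: str = "memory_efficient"
    peft_use_dora: bool = False   # assumed field name
    peft_use_rslora: bool = False  # assumed field name

    @model_validator(mode="after")
    def validate_merge_method(self):
        if self.merge_method == "memory_efficient" and (
            self.peft_use_dora or self.peft_use_rslora
        ):
            raise ValueError(
                "memory_efficient merge does not support DoRA/RSLoRA; "
                "use the legacy merge_method instead"
            )
        return self
```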
```python
        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614
    except TypeError:
        lora_state = torch.load(lora_file, map_location="cpu")  # nosec B614
```
why do we need this try/except? can you choose one loading method and stick to it?
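For illustration, one way to keep a single loading path, assuming the adapter ships as either adapter_model.safetensors or adapter_model.bin (the usual PEFT file names, not verified against this PR):

```python
# Sketch: one loading path per adapter format, no try/except fallback.
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_adapter_state(lora_dir: str) -> dict[str, torch.Tensor]:
    adapter_dir = Path(lora_dir)
    st_file = adapter_dir / "adapter_model.safetensors"
    if st_file.exists():
        return load_file(str(st_file), device="cpu")
    return torch.load(  # nosec B614: trusted adapter weights
        adapter_dir / "adapter_model.bin", map_location="cpu", weights_only=True
    )
```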
```python
    output_path = Path(output_path)

    if "/" in str(base_model_path) and not base_model_path.exists():
        from huggingface_hub import snapshot_download
```
Can be a top-level import.
```python
def copy_non_model_files(
    input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
    """
    Copy all non-model files to the output directory.

    Args:
        input_path: Source directory
        output_path: Destination directory
        model_shards: List of model shard files to skip
    """
    LOG.info("Copying non-model files to output directory...")

    shard_names = {shard.name for shard in model_shards}

    for filepath in input_path.glob("*"):
        if filepath.is_dir():
            continue
        if filepath.name in shard_names:
            continue
        if (
            filepath.name.startswith("model") and filepath.suffix == ".safetensors"
        ) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
            continue
        if filepath.suffix == ".gguf":
            continue

        LOG.debug(f"Copying {filepath.name} to output")
        shutil.copy2(filepath, output_path)
```
What is the purpose of this? Can you check that we / transformers don't already have a utility for this?
we can remove this if we don't need files like config.json in the output directory; not sure either way. I didn't find anything similar to this in transformers for separating out the non-model files.
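If the copy step stays, the standard library already covers most of it; a sketch, with ignore patterns that are assumed to mirror the current skip logic:

```python
# Sketch: copy everything except weight shards using stdlib copytree.
# The ignore patterns are assumptions mirroring the skip logic above.
import shutil

shutil.copytree(
    "base-model-dir",
    "merged-output-dir",
    dirs_exist_ok=True,
    ignore=shutil.ignore_patterns(
        "model*.safetensors", "pytorch_model*.bin", "*.gguf"
    ),
)
```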
```python
        LOG.warning(
            "safe_tensors=True requested but input shards are .bin; preserving .bin format "
            "to avoid index mismatches."
        )
```
this is a bit confusing. if the user requests safe_tensors, shouldn't we convert them to safetensors?
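If conversion is the desired behaviour, a sketch of rewriting one .bin shard as safetensors; the shard names are illustrative, and the weight index file would still need regenerating to point at the new names:

```python
# Sketch: convert a single .bin shard to safetensors instead of preserving .bin.
# Tied/shared tensors must be de-duplicated before saving; omitted here.
import torch
from safetensors.torch import save_file

state_dict = torch.load(
    "pytorch_model-00001-of-00002.bin", map_location="cpu", weights_only=True
)
# safetensors requires contiguous tensors
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
save_file(state_dict, "model-00001-of-00002.safetensors", metadata={"format": "pt"})
```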
```python
                    original_dtype = tensor.dtype
                    base_fp32 = tensor.to(device).to(torch.float32)
                    a_fp32 = lora_a.to(device).to(torch.float32)
                    b_fp32 = lora_b.to(device).to(torch.float32)
                    delta = scale * (b_fp32 @ a_fp32)
                    if bool(
                        lora_config_dict.get("fan_in_fan_out", False)
                        or lora_config_dict.get("lora_fan_in_fan_out", False)
                    ):
                        delta = delta.T
                    merged_tensors[key] = (
                        (base_fp32 + delta).to(original_dtype).detach().cpu()
                    )
                    del base_fp32, a_fp32, b_fp32, delta
```
this appears to be duplicated below; it can be factored out into a helper method
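A sketch of such a helper, reusing the variable names from the snippet above; the fan_in_fan_out flag is passed in rather than read from the config:

```python
# Sketch: shared per-tensor merge usable by both the safetensors and .bin branches.
import torch


def merge_lora_into_tensor(
    tensor: torch.Tensor,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    scale: float,
    fan_in_fan_out: bool,
    device: str,
) -> torch.Tensor:
    original_dtype = tensor.dtype
    base_fp32 = tensor.to(device).to(torch.float32)
    a_fp32 = lora_a.to(device).to(torch.float32)
    b_fp32 = lora_b.to(device).to(torch.float32)
    delta = scale * (b_fp32 @ a_fp32)
    if fan_in_fan_out:
        delta = delta.T
    return (base_fp32 + delta).to(original_dtype).detach().cpu()
```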
Co-authored-by: Dan Saunders <[email protected]>
Co-authored-by: Dan Saunders <[email protected]>
Description
feature 1: The merge-lora script never loads the full model into memory. It iterates through each of the .bin or .safetensors shards and applies the LoRA delta to each module as needed, which makes it far more memory-efficient than the standard approach. A rough sketch of the per-shard idea is shown below.
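A minimal sketch of that loop, assuming safetensors shards and the usual PEFT adapter key naming; function names, paths, and key prefixes are illustrative rather than the exact implementation in this PR:

```python
# Sketch of the shard-wise merge: stream one shard at a time, patch matching
# weights with scale * (B @ A), and write the shard back out.
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def merge_shard(shard_path: Path, lora_state: dict, scale: float, out_dir: Path) -> None:
    shard = load_file(str(shard_path), device="cpu")
    for key, weight in shard.items():
        prefix = key.removesuffix(".weight")
        # Adapter keys assumed to follow the usual PEFT convention.
        lora_a = lora_state.get(f"base_model.model.{prefix}.lora_A.weight")
        lora_b = lora_state.get(f"base_model.model.{prefix}.lora_B.weight")
        if lora_a is None or lora_b is None:
            continue  # no adapter for this module; keep the base weight
        delta = scale * (lora_b.float() @ lora_a.float())
        shard[key] = (weight.float() + delta).to(weight.dtype)
    save_file(shard, str(out_dir / shard_path.name), metadata={"format": "pt"})
```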
new file `lora_merge_efficient`: core implementation
new parameter `merge_method`: standard / memory_efficient
Motivation and Context
#1679
references
qlora-pipe/tools/merge_lora.py
Tests
tested with examples/llama-3/qlora-1b.yml
with a tiny Llama 1B instruct model and merge_method: memory_efficient
Summary by CodeRabbit
New Features
Chores
Documentation