[peft] fix: Add replica_id handling for dense LoRA adapters with TP > 1 #2252

Open

yaoyu-33 wants to merge 2 commits into main from fix/dense-lora-adapter-tp-replica-id

Conversation

yaoyu-33 (Contributor) commented Feb 6, 2026

Summary

When using LoRA adapters on dense layers (non-MoE) with TP > 1, only one TP shard was being saved during checkpointing. This caused a significant training-inference loss discrepancy because the other TP ranks loaded zero or uninitialized adapter weights.

Root Cause

In ParallelLinearAdapter.sharded_state_dict(), there was special replica_id handling for expert adapters (added in PR #1564, related to verl-project/verl#4303), but dense adapters never received the equivalent fix.

For dense adapters with TP > 1 (illustrated in the sketch after this list):

  • Each TP rank generates a ShardedTensor with the same replica_id
  • Shards should be distinguished by global_offset, but during the PEFT-filtered checkpoint save, TP shards are incorrectly deduplicated
  • Result: Only TP rank 0's shard is saved
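
The collision can be illustrated with a small, self-contained sketch. The key name and shapes are hypothetical; it only uses megatron.core's ShardedTensor.from_rank_offsets to build checkpoint metadata, so no distributed setup is required:

```python
import torch
from megatron.core.dist_checkpointing import ShardedTensor


def adapter_shard(tp_rank: int, tp_size: int) -> ShardedTensor:
    # Hypothetical adapter weight whose dim 0 is sharded across TP ranks.
    local_weight = torch.zeros(16, 32)
    return ShardedTensor.from_rank_offsets(
        "adapter.linear_out.weight",  # same key on every TP rank
        local_weight,
        (0, tp_rank, tp_size),        # (dim, rank offset, fragmentation): offsets differ per rank
        replica_id=0,                 # identical on every rank: this is the collision
    )


shards = [adapter_shard(rank, 2) for rank in range(2)]
assert shards[0].global_offset != shards[1].global_offset  # shards cover different slices
assert shards[0].replica_id == shards[1].replica_id        # yet they look like replicas to the dedup pass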

Solution

Add replica_id handling for dense adapters, mirroring the existing handling for expert adapters. When TP > 1, the replica_id is adjusted to include the TP rank, ensuring each TP shard is correctly identified and saved during PEFT-filtered checkpoint saves.
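
A minimal sketch of the kind of adjustment described here, not the actual patch: it assumes replica_id is a 3-tuple (as the walkthrough below suggests) and that the adapter's sharded state dicts are plain dicts of ShardedTensors; the helper name _add_tp_rank_to_replica_id is hypothetical.

```python
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor


def _add_tp_rank_to_replica_id(sharded_sd: dict) -> None:
    """Sketch: fold the TP rank into replica_id so dense-adapter TP shards
    survive deduplication during a PEFT-filtered checkpoint save."""
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if tp_size <= 1:
        return  # TP = 1: leave replica_id untouched
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    for sh_ten in sharded_sd.values():
        if isinstance(sh_ten, ShardedTensor):
            dim0, _, dim2 = sh_ten.replica_id  # assumed 3-tuple layout
            sh_ten.replica_id = (dim0, tp_rank, dim2)
```

In ParallelLinearAdapter.sharded_state_dict() this kind of adjustment would be applied to both linear_in_sd and linear_out_sd, mirroring the expert-adapter branch added in PR #1564.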

Changes

  • Add replica_id adjustment for dense adapters with TP > 1 in ParallelLinearAdapter.sharded_state_dict()
  • Add unit tests for the fix covering (a rough sketch of the TP > 1 case follows this list):
    • Dense adapters with TP > 1 (replica_id correctly updated)
    • Dense adapters with TP = 1 (no change to replica_id)
    • Expert adapters (still use EP-based replica_id, not affected)
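
A rough shape for the first of those tests, reusing the _add_tp_rank_to_replica_id sketch above and patching parallel_state so no real TP group is needed; the actual tests live in tests/unit_tests/peft/test_utils.py and exercise ParallelLinearAdapter directly.

```python
import torch
from unittest import mock
from megatron.core.dist_checkpointing import ShardedTensor


def test_tp_rank_folded_into_replica_id():
    # Build a shard as TP rank 1 of 2 with the original (0, 0, 0) replica_id.
    sh_ten = ShardedTensor.from_rank_offsets(
        "adapter.linear_in.weight", torch.zeros(8, 8), (0, 1, 2), replica_id=(0, 0, 0)
    )
    with mock.patch(
        "megatron.core.parallel_state.get_tensor_model_parallel_world_size", return_value=2
    ), mock.patch(
        "megatron.core.parallel_state.get_tensor_model_parallel_rank", return_value=1
    ):
        # _add_tp_rank_to_replica_id is the hypothetical sketch helper shown above.
        _add_tp_rank_to_replica_id({"weight": sh_ten})
    assert sh_ten.replica_id == (0, 1, 0)  # TP rank now distinguishes the shard
```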

Testing

  • Unit tests added
  • Functional tests with multi-GPU TP > 1 (requires CI)

Summary by CodeRabbit

  • Bug Fixes

    • Fixed replica ID assignment for adapter checkpoints when using tensor model parallel configurations, ensuring correct deduplication during checkpoint saving.
  • Tests

    • Added comprehensive test coverage for dense and expert adapter scenarios across different parallel configurations.

When using LoRA adapters on dense layers (non-MoE) with TP > 1, only one
TP shard was being saved during checkpointing. This caused significant
training-inference loss discrepancy because other TP ranks loaded
zero/uninitialized adapter weights.

The fix for expert adapters already existed (PR #1564, related to
verl-project/verl#4303), but dense adapters never received the
equivalent fix.

This commit adds replica_id handling for dense adapters similar to
expert adapters. When TP > 1, the replica_id is adjusted to include
the TP rank, ensuring each TP shard is correctly identified and saved
during PEFT-filtered checkpoint saves.

Changes:
- Add replica_id adjustment for dense adapters with TP > 1 in
  ParallelLinearAdapter.sharded_state_dict()
- Add unit tests for the fix covering:
  - Dense adapters with TP > 1 (replica_id correctly updated)
  - Dense adapters with TP = 1 (no change to replica_id)
  - Expert adapters (still use EP-based replica_id, not affected)

copy-pr-bot bot commented Feb 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


yaoyu-33 commented Feb 6, 2026

/ok to test acaef1a


coderabbitai bot commented Feb 6, 2026

📝 Walkthrough

This pull request modifies PEFT state dict handling for dense adapters with tensor model parallel (TP) size greater than 1, adjusting replica_id values for correct checkpoint deduplication. Comprehensive test coverage is added for parallel linear adapter state dict scenarios across different parallelism configurations.

Changes

PEFT State Dict Logic (src/megatron/bridge/peft/utils.py)
Removed the in-function import of parallel_state and added a new else branch for dense adapters with TP > 1 that adjusts replica_id to (original_dim0, tp_rank, original_dim2) across linear_in_sd and linear_out_sd, so shards are deduplicated correctly when saving PEFT-filtered checkpoints.

Test Coverage (tests/unit_tests/peft/test_utils.py)
Added three new test cases for ParallelLinearAdapter sharded state dict handling: one for dense adapters with TP > 1 verifying that replica_id includes the TP rank, one for TP = 1 verifying that replica_id is unchanged, and one for expert adapters verifying that the EP-based calculation is unaffected by the dense adapter fix.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 62.50%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Test Results For Major Changes: ✅ Passed. The PR qualifies as a minor change: a small, focused code change (+14 lines) with comprehensive unit tests (+157 lines) covering the key scenarios, though functional tests are pending in CI.
  • Title check: ✅ Passed. The title clearly and concisely describes the main change, fixing replica_id handling for dense LoRA adapters when TP > 1, which is the primary purpose of the changeset.


…with TP=2

Add TestLoRAFinetuneTP2 class with two tests:
- test_lora_save_and_resume_tp2: End-to-end save/resume test with TP=2
- test_lora_weights_preserved_after_save_load_tp2: Explicit verification that
  loaded adapter weights exactly match saved weights on all TP ranks

The second test specifically catches the replica_id bug (see the condensed sketch after this list) by:
1. Capturing adapter weights before checkpoint save
2. Loading checkpoint into fresh model
3. Comparing loaded vs saved weights
4. Failing with clear error if loaded weights are all zeros (bug symptom)
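
A condensed sketch of that verification logic; the model handles and the "adapter" parameter-naming convention are assumptions, and the real test runs the full TP=2 finetune/save/load cycle:

```python
import torch


def check_adapter_weights_preserved(model_before, model_after) -> None:
    # Capture adapter weights before the checkpoint save (assumed "adapter" naming).
    saved = {
        name: p.detach().clone()
        for name, p in model_before.named_parameters()
        if "adapter" in name
    }
    # Compare against the weights loaded into a fresh model on this TP rank.
    for name, p in model_after.named_parameters():
        if name not in saved:
            continue
        if torch.all(p == 0) and not torch.all(saved[name] == 0):
            raise AssertionError(
                f"{name} loaded as all zeros; classic symptom of the replica_id dedup bug"
            )
        torch.testing.assert_close(p, saved[name])
```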

yaoyu-33 commented Feb 6, 2026

/ok to test 8094b2b

yaoyu-33 changed the title from "fix(peft): Add replica_id handling for dense LoRA adapters with TP > 1" to "[peft] fix: Add replica_id handling for dense LoRA adapters with TP > 1" on Feb 6, 2026

priyatham-resolve left a comment


LGTM. Thanks for the fix

return pretrain_checkpoint_dir, pretrain_tensorboard_dir, lora_checkpoint_dir, lora_tensorboard_dir


class TestLoRAFinetuneTP2:


nit: This seems to duplicate all the helper methods from TestLoRAFinetune above. Could share them via a base class or mixin to reduce the ~300 lines of boilerplate? Not a blocker.
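
For illustration, a minimal sketch of the kind of mixin refactor this nit suggests; the class and method names here are hypothetical:

```python
class _LoRAFinetuneHelpers:
    """Hypothetical mixin collecting the helper methods shared by both test classes."""

    def run_pretrain(self, tmp_path, tp_size: int = 1):
        ...  # shared pretraining setup

    def run_lora_finetune(self, checkpoint_dir, tp_size: int = 1):
        ...  # shared LoRA finetune / save / resume logic


class TestLoRAFinetune(_LoRAFinetuneHelpers):
    ...  # existing TP=1 tests


class TestLoRAFinetuneTP2(_LoRAFinetuneHelpers):
    ...  # TP=2 tests reuse the same helpers with tp_size=2
```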
