[peft] fix: Add replica_id handling for dense LoRA adapters with TP > 1 #2252
Conversation
When using LoRA adapters on dense layers (non-MoE) with TP > 1, only one TP shard was being saved during checkpointing. This caused significant training-inference loss discrepancy because the other TP ranks loaded zero/uninitialized adapter weights. The fix for expert adapters already existed (PR #1564, related to verl-project/verl#4303), but dense adapters never received the equivalent fix.

This commit adds replica_id handling for dense adapters similar to expert adapters. When TP > 1, the replica_id is adjusted to include the TP rank, ensuring each TP shard is correctly identified and saved during PEFT-filtered checkpoint saves.

Changes:
- Add replica_id adjustment for dense adapters with TP > 1 in ParallelLinearAdapter.sharded_state_dict()
- Add unit tests for the fix covering:
  - Dense adapters with TP > 1 (replica_id correctly updated)
  - Dense adapters with TP = 1 (no change to replica_id)
  - Expert adapters (still use EP-based replica_id, not affected)
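The core of the adjustment can be sketched as a small helper (hypothetical name; the actual change lives inline in `ParallelLinearAdapter.sharded_state_dict()` and queries Megatron's parallel state for the TP rank):

```python
def adjust_replica_id(replica_id, tp_size, tp_rank):
    """Fold the TP rank into the replica_id so each TP shard of a dense
    adapter is treated as a distinct shard rather than a redundant replica.

    replica_id here is a (prepend, mid, dp_rank)-style tuple; only the
    middle slot is repurposed, mirroring what expert adapters already do
    with the EP rank.
    """
    if tp_size > 1:
        prepend, _, dp_rank = replica_id
        return (prepend, tp_rank, dp_rank)
    return replica_id  # TP = 1: replica_id unchanged
```

With TP = 2, rank 0 and rank 1 now carry distinct replica_ids, so the checkpoint saver writes both shards instead of deduplicating one away.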
/ok to test acaef1a
📝 Walkthrough: This pull request modifies PEFT state dict handling for dense adapters with tensor model parallel (TP) size greater than 1.
…with TP=2

Add TestLoRAFinetuneTP2 class with two tests:
- test_lora_save_and_resume_tp2: End-to-end save/resume test with TP=2
- test_lora_weights_preserved_after_save_load_tp2: Explicit verification that loaded adapter weights exactly match saved weights on all TP ranks

The second test specifically catches the replica_id bug by:
1. Capturing adapter weights before checkpoint save
2. Loading checkpoint into fresh model
3. Comparing loaded vs saved weights
4. Failing with clear error if loaded weights are all zeros (bug symptom)
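The comparison step in the second test can be sketched as follows (a pure-Python stand-in; the real test compares `torch` tensors gathered per TP rank, and names here are illustrative):

```python
def find_bad_adapter_weights(saved, loaded):
    """Compare per-parameter weights loaded from a checkpoint against the
    ones captured before saving; return human-readable failure messages."""
    problems = []
    for name, before in saved.items():
        after = loaded.get(name)
        if after is None:
            problems.append(f"{name}: missing after load")
        elif all(v == 0 for v in after) and any(v != 0 for v in before):
            # the replica_id bug symptom: a TP shard came back all zeros
            problems.append(f"{name}: loaded weights are all zeros")
        elif after != before:
            problems.append(f"{name}: loaded weights differ from saved")
    return problems
```

Checking for the all-zeros case separately gives a clear error message pointing at the missing-shard bug rather than a generic mismatch.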
/ok to test 8094b2b
class TestLoRAFinetuneTP2:
nit: This seems to duplicate all the helper methods from TestLoRAFinetune above. Could share them via a base class or mixin to reduce the ~300 lines of boilerplate? Not a blocker.
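The suggested refactor could look something like this (a sketch only; the class names beyond TestLoRAFinetuneTP2 and the helper shown are assumptions, not the actual test code):

```python
import pathlib
import tempfile

class LoRAFinetuneHelpersMixin:
    """Shared helpers so TestLoRAFinetune and TestLoRAFinetuneTP2 don't each
    carry ~300 lines of duplicated boilerplate (hypothetical helper shown)."""

    def make_checkpoint_dir(self, base):
        # illustrative shared setup helper
        path = pathlib.Path(base) / "checkpoints"
        path.mkdir(parents=True, exist_ok=True)
        return path

class TestLoRAFinetuneTP2(LoRAFinetuneHelpersMixin):
    # only the TP-specific configuration lives in the subclass
    tensor_parallel_size = 2
```

Each test class then overrides only the parallelism configuration, keeping the save/load/compare machinery in one place.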
Summary
When using LoRA adapters on dense layers (non-MoE) with TP > 1, only one TP shard was being saved during checkpointing. This caused significant training-inference loss discrepancy because other TP ranks loaded zero/uninitialized adapter weights.
Root Cause
In `ParallelLinearAdapter.sharded_state_dict()`, there was special `replica_id` handling for expert adapters (added in PR #1564, related to verl-project/verl#4303), but dense adapters never received the equivalent fix.

For dense adapters with TP > 1:

- `replica_id` does not encode the TP rank
- the TP split is reflected in `global_offset`, but during the PEFT-filtered checkpoint save, TP shards are incorrectly deduplicated
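A toy model of the deduplication (not the actual Megatron saver logic) shows why identical `replica_id`s lose shards:

```python
def shards_written(replica_ids):
    """Toy dedup rule: shards carrying the same replica_id are assumed to
    be copies of one another, so only one of them is written to disk."""
    return len(set(replica_ids))

# Before the fix: with TP = 2, both TP ranks report (0, 0, dp_rank)
# for a dense adapter weight, so only one shard is written.
before = shards_written([(0, 0, 0), (0, 0, 0)])

# After the fix: the TP rank is folded into the replica_id,
# so both shards are distinct and both are written.
after = shards_written([(0, 0, 0), (0, 1, 0)])
```

On load, the rank whose shard was never written falls back to zero/uninitialized adapter weights, producing the observed loss discrepancy.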
Add
Add `replica_id` handling for dense adapters similar to expert adapters. When TP > 1, the `replica_id` is adjusted to include the TP rank, ensuring each TP shard is correctly identified and saved during PEFT-filtered checkpoint saves.

Changes
- Add `replica_id` adjustment for dense adapters with TP > 1 in `ParallelLinearAdapter.sharded_state_dict()`
- Add unit tests covering dense adapters with TP > 1, dense adapters with TP = 1, and expert adapters

Testing

- Unit tests: dense adapters with TP > 1 (`replica_id` correctly updated), TP = 1 (no change to `replica_id`), expert adapters (still use EP-based `replica_id`, not affected)
- Integration tests with TP=2: end-to-end save/resume, plus explicit verification that loaded adapter weights match saved weights on all TP ranks
Summary by CodeRabbit

Bug Fixes

- Dense LoRA adapters with tensor parallelism greater than 1 now save all TP shards during checkpointing, so adapter weights are preserved on every rank.

Tests

- Added unit and integration tests covering adapter checkpointing with TP > 1 and TP = 1.