[https://nvbugs/5916092][fix] Fix MTP+PP hang by preserving speculative layer weights on last PP rank #12555
xxi-nv wants to merge 1 commit into NVIDIA:main
Conversation
…ve layer weights on last PP rank

`DecoderModel.__pp_init__` iterates all layers, including the MTP speculative layers appended via `layers.extend(mtp_layers)`. Since MTP layer indices exceed `num_hidden_layers`, they are not in `pp_layer_list` and get `skip_forward()` on ALL ranks, which replaces `forward` with a no-op AND removes the weights. On non-last PP ranks this is correct (the MTP layers are unused). But on the last PP rank, the MTP draft worker needs the layer weights; removing them causes a hang during generation.

Fix: for layers beyond `num_hidden_layers`, always replace `forward` with `skip_forward` (so they are no-ops in the main decoder loop on all ranks), but only remove weights on non-last PP ranks. The last PP rank preserves the weights so the MTP speculative decoding worker can use them.

Affects all models using `layers.extend(mtp_layers)`: DeepSeekV3, NemotronH, ExaoneMoE, and GLM.

Signed-off-by: xxi <xxi@nvidia.com>
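The rank-dependent handling described above can be sketched as a minimal, self-contained toy. Names such as `skip_forward`, `remove_weights`, `pp_layer_list`, and the `pp_init` helper are modeled on the PR description, not the actual TensorRT-LLM implementation:

```python
# Toy sketch of the fixed __pp_init__ behavior described in this PR.
# All names here are illustrative stand-ins, not the real TensorRT-LLM code.

class Layer:
    def __init__(self, idx):
        self.idx = idx
        self.weights = {"w": f"weights-{idx}"}

    def forward(self, x):
        return x + 1  # stand-in for real layer computation

    def skip_forward(self, x):
        return x  # no-op passthrough


def remove_weights(layer):
    layer.weights = {}  # free memory for layers this rank never executes


def pp_init(layers, num_hidden_layers, pp_layer_list, is_last_pp_rank):
    for idx, layer in enumerate(layers):
        if idx >= num_hidden_layers:
            # Extra (MTP speculative) layer: no-op in the main decoder
            # loop on every rank, but only drop weights on non-last ranks
            # so the draft worker on the last rank can still use them.
            layer.forward = layer.skip_forward
            if not is_last_pp_rank:
                remove_weights(layer)
            continue
        if idx not in pp_layer_list:
            # Base-model layer owned by another rank: skip and free.
            layer.forward = layer.skip_forward
            remove_weights(layer)


# Last PP rank owning layers [2, 3] of a 4-layer model, plus one MTP layer.
layers = [Layer(i) for i in range(5)]
pp_init(layers, num_hidden_layers=4, pp_layer_list=[2, 3], is_last_pp_rank=True)
assert layers[4].weights            # MTP weights preserved for the draft worker
assert layers[4].forward(7) == 7    # yet a no-op in the main decoder loop
assert not layers[0].weights        # non-owned base layers still freed
```

Running the same `pp_init` with `is_last_pp_rank=False` would strip the MTP layer's weights as well, which matches the pre-fix behavior that caused the hang only on the last rank.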
/bot run
📝 Walkthrough

Modified pipeline-parallel layer initialization in
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_utils.py (1)
307-316: Good fix for the MTP+PP hang issue. The logic correctly preserves weights on the last PP rank while making the extra layers no-ops in the main decoder loop. The separation of concerns (skip forward vs. remove weights) is well reasoned.
Consider adding a warning when `skip_forward` is unavailable. If a layer doesn't have `skip_forward`, its `forward` remains unchanged. On the last PP rank, this means the layer would execute normally in the main decoder loop instead of being a no-op. While unlikely with MTP layers (`DecoderLayer` instances have `skip_forward`), a warning would be consistent with the existing `skip_forward` function at lines 163-165 and help debug unexpected behavior.

Suggested defensive warning:
```diff
 if layer_idx >= num_hidden_layers:
     # Extra layers (e.g., MTP speculative layers) appended beyond
     # the base model. Skip their forward on all ranks so they are
     # no-ops in the main decoder loop, but preserve weights on the
     # last PP rank where the MTP draft worker needs them.
     if hasattr(layer, 'skip_forward'):
         layer.forward = layer.skip_forward
+    else:
+        logger.warning(
+            f"Layer {layer_idx} ({layer.__class__.__name__}) does not have "
+            f"`skip_forward`; it will not be a no-op in the main decoder loop.")
     if not mapping.is_last_pp_rank():
         remove_weights(layer)
     continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/models/modeling_utils.py` around lines 307-316: when handling extra layers (layer_idx >= num_hidden_layers), add a defensive warning if a layer does not have skip_forward. If hasattr(layer, 'skip_forward') is false, emit a warning (e.g., logging.warning or warnings.warn) that the extra layer lacks skip_forward so its forward will remain active on the last PP rank; keep the existing behavior of calling layer.skip_forward when present and remove_weights when not mapping.is_last_pp_rank(), but log this unexpected condition referencing layer_idx, the layer object, and mapping.is_last_pp_rank() to aid debugging (mirror the existing skip_forward handling used elsewhere).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 34356177-7950-4e51-9c55-a9da2695250f
📒 Files selected for processing (2)
- tensorrt_llm/_torch/models/modeling_utils.py
- tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
- tests/integration/test_lists/waives.txt
PR_Github #40435 [ run ] triggered by Bot. Commit:
PR_Github #40435 [ run ] completed with state
Summary
- Fixes a hang in `DecoderModel.__pp_init__()` when both pipeline parallelism (PP) and MTP speculative decoding are enabled
- MTP layers beyond `num_hidden_layers` were incorrectly getting `skip_forward()` + `remove_weights()` on all PP ranks, including the last rank where the MTP draft worker needs them
- Fix for layers beyond `num_hidden_layers`: always make them no-ops in the main decoder loop (all ranks), but only remove weights on non-last PP ranks
- Affects: DeepSeekV3, NemotronH, ExaoneMoE, GLM (all use `self.model.layers.extend(self.draft_model.mtp_layers)`)

Root cause

`__pp_init__` iterates ALL `self.layers`, including MTP layers at indices >= `num_hidden_layers`. Since these indices are not in `pp_layer_list`, the existing code calls `skip_forward()`, which:

- replaces `layer.forward` with a no-op
- calls `remove_weights()` to free GPU memory

This happens on ALL ranks, including the last PP rank. The MTP draft worker (running on rank N-1) then finds its speculative layers have no weights → hang.
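The root cause can be illustrated with a toy version of the pre-fix condition. All names here (`pp_layer_list`, `num_hidden_layers`) are assumptions modeled on the description above, not the real TensorRT-LLM code:

```python
# Toy illustration of the pre-fix bug. Because MTP layer indices sit beyond
# num_hidden_layers, they can never appear in pp_layer_list, so the old
# "not in pp_layer_list -> skip + strip" rule removed their weights on
# every rank, including the last one where the draft worker needs them.
num_hidden_layers = 4
pp_layer_list = [2, 3]          # base-model layers owned by the last PP rank
mtp_layer_idx = 4               # appended via layers.extend(mtp_layers)

stripped_on_last_rank = mtp_layer_idx not in pp_layer_list  # old condition
print(stripped_on_last_rank)    # True: draft worker finds no weights -> hang
```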
Fix
Test plan
- `TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-...]` should pass (was hanging)
- Corresponding entry removed from `waives.txt`

Summary by CodeRabbit
Bug Fixes
Tests