[https://nvbugs/5879577][fix] Fix KeyError in DeepSeekV3Lite FP8 MTP weight loading#12530
sunnyqgg wants to merge 2 commits into NVIDIA:main
Conversation
…weight loading

When model_nextn > ckpt_nextn (e.g., the model requests 2 MTP layers but the checkpoint provides only 1), multiple model MTP layers map to the same checkpoint layer via modulo remapping. ConsumableWeightsDict.mark_consumed() deletes checkpoint weights after the first MTP layer loads them, causing a KeyError when subsequent MTP layers try to load the same weights.

Skip mark_consumed for MTP layers with shared checkpoint weights to prevent premature deletion of weights needed by later MTP layers.

Signed-off-by: qgai <qgai@nvidia.com>
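The modulo remapping described in the commit message can be sketched as follows. This is an illustrative example, not the actual TensorRT-LLM code: the function name and the `ckpt_nextn`/`model_nextn` parameters are assumptions taken from the terminology above.

```python
# Illustrative sketch of the modulo remapping described above.
# All names here (remap_mtp_layer, ckpt_nextn, model_nextn) are
# hypothetical; only the behavior follows the commit message.

def remap_mtp_layer(module_idx: int, num_hidden_layers: int,
                    ckpt_nextn: int, model_nextn: int):
    """Map a model MTP layer index to a checkpoint layer index.

    Returns (checkpoint_index, shared): `shared` is True when more
    model MTP layers exist than checkpoint MTP layers, i.e. several
    model layers will read the same checkpoint weights, so those
    weights must not be consumed after the first read.
    """
    mtp_slot = module_idx - num_hidden_layers      # 0-based MTP position
    ckpt_idx = num_hidden_layers + (mtp_slot % ckpt_nextn)
    shared = model_nextn > ckpt_nextn
    return ckpt_idx, shared
```

For example, with 30 hidden layers, `ckpt_nextn=1`, and `model_nextn=2`, both model MTP layers 30 and 31 map to checkpoint layer 30 — the layer named in the KeyError this PR fixes.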
…P8 MTP

Remove the test waiver now that the underlying KeyError is fixed.

Signed-off-by: qgai <qgai@nvidia.com>
📝 Walkthrough

Modified DeepseekV3 weight loading to detect and handle MTP checkpoint-weight sharing by comparing num_nextn_predict_layers between the checkpoint and the spec config. When shared MTP weights are detected, module indices are remapped via modulo and weight consumption is skipped for the affected layers. A test waive entry is removed.
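A minimal sketch of the skip-consumption behavior the walkthrough describes, assuming the `shared` flag has already been determined from the config comparison. `ConsumableWeightsDict` here is a toy stand-in for the real class, and `load_layer_weight` is a hypothetical helper, not TensorRT-LLM's actual loading code.

```python
# Toy stand-in for the real ConsumableWeightsDict: weights are deleted
# once marked consumed, which is the behavior that triggered the bug.
class ConsumableWeightsDict(dict):
    def mark_consumed(self, key):
        del self[key]

def load_layer_weight(weights, ckpt_key, shared):
    """Load a weight; only consume it when no other MTP layer shares it."""
    tensor = weights[ckpt_key]
    if not shared:                      # the fix: skip consumption when shared
        weights.mark_consumed(ckpt_key)
    return tensor

k = "model.layers.30.self_attn.kv_a_proj_with_mqa.weight"
w = ConsumableWeightsDict({k: 0})
load_layer_weight(w, k, shared=True)    # first model MTP layer
load_layer_weight(w, k, shared=True)    # second layer: weight still present
```

With `shared=False` the weight is consumed after the first load, preserving the memory-freeing behavior for the normal, non-shared case.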
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
/bot run --extra-stage "DGX_H100_PCIe-PyTorch-Post-Merge-2"
PR_Github #40291 [ run ] triggered by Bot. Commit:
PR_Github #40291 [ run ] completed with state
Summary

- KeyError: 'model.layers.30.self_attn.kv_a_proj_with_mqa.weight' when loading DeepSeekV3Lite FP8 with vanilla MTP (num_nextn_predict_layers=2)
- ConsumableWeightsDict.mark_consumed() deletes checkpoint weights after the first MTP layer loads them, but when model_nextn > ckpt_nextn, subsequent MTP layers share the same checkpoint weights via modulo remapping and fail with a KeyError
- Skip mark_consumed for MTP layers when checkpoint weights are shared across multiple model MTP layers
- Remove the corresponding waives.txt entry

Test plan

- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] on H100
- /bot run --extra-stage "DGX_H100_PCIe-PyTorch-Post-Merge-2"
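To make the failure mode above concrete, here is a hedged reproduction of the pre-fix behavior using a toy consumable dict. The class name and the always-consume step are assumptions reconstructed from the PR text, not the real implementation.

```python
# Toy reproduction of the pre-fix bug: mark_consumed deletes the weight
# after the first MTP layer loads it, so a second MTP layer remapped to
# the same checkpoint layer hits a KeyError.
class ConsumableWeightsDict(dict):
    def mark_consumed(self, key):
        del self[key]

key = "model.layers.30.self_attn.kv_a_proj_with_mqa.weight"
weights = ConsumableWeightsDict({key: 0})

weights[key]                  # model MTP layer 0 loads the weight
weights.mark_consumed(key)    # pre-fix behavior: always consume after load

try:
    weights[key]              # model MTP layer 1, remapped to the same key
    failed = False
except KeyError:
    failed = True             # the error this PR fixes
```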