Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,20 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
# Check if weights supports mark_consumed (ConsumableWeightsDict)
can_mark_consumed = hasattr(weights, 'mark_consumed')

# Detect if MTP layers share checkpoint weights (model requests more
# MTP layers than the checkpoint provides). In this case, multiple
# model MTP layers map to the same checkpoint layer via modulo, and
# mark_consumed must be skipped to avoid deleting weights that later
# MTP layers still need.
ckpt_nextn = self.config.num_nextn_predict_layers or 0
model_nextn = 0
spec_config = self.model_config.spec_config
if spec_config is not None and hasattr(
spec_config, 'spec_dec_mode'
) and spec_config.spec_dec_mode.is_mtp_one_model():
model_nextn = spec_config.num_nextn_predict_layers
has_shared_mtp_weights = model_nextn > ckpt_nextn > 0

for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
if len(module._parameters) <= 0 or name.startswith("draft_model"):
Expand All @@ -342,8 +356,10 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
else:
names = name.split('.')
parent_module_name = '.'.join(names[:-1])
is_shared_mtp_layer = False
if "model.layers" in name and int(
names[2]) >= self.config.num_hidden_layers:
is_shared_mtp_layer = has_shared_mtp_weights
mtp_layer_idx = int(
names[2]) - self.config.num_hidden_layers
names[2] = str(mtp_layer_idx %
Expand Down Expand Up @@ -409,7 +425,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
).view(*attn_module.v_b_proj_dequant.shape).to(
attn_module.v_b_proj_dequant.dtype))
# Mark consumed kv_b_proj weights
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
weights.mark_consumed(name)
elif names[-1] == "kv_a_proj_with_mqa":
nvfp4_fused_a = self.model_config.get_quant_config(
Expand Down Expand Up @@ -534,7 +550,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
0:fused_a_scale.shape[0]].copy_(fused_a_scale)
module.weight.data[0:fused_a.shape[0]].copy_(fused_a)
# Mark consumed kv_a_proj_with_mqa and q_a_proj weights
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
parent_prefix = '.'.join(names[:-1])
weights.mark_consumed(
f"{parent_prefix}.kv_a_proj_with_mqa")
Expand All @@ -548,7 +564,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
weights))
module.load_weights(weights=module_weights)
# Mark consumed source weights (e.g., gate_proj, up_proj)
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
for src_name in params_map[names[-1]]:
weights.mark_consumed('.'.join(names[:-1] +
[src_name]))
Expand All @@ -561,7 +577,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
})
module.load_weights(weights=[module_weights])
# Mark consumed experts weights
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
weights.mark_consumed(name)
elif names[-1] == "backend" and isinstance(module, MoE):
# Special case: ConfigurableMoE.backend (TRTLLMGenFusedMoE)
Expand All @@ -579,7 +595,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
})
module.load_weights(weights=[module_weights])
# Mark consumed MoE weights using parent name
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
weights.mark_consumed(parent_name)
elif names[-1] == "self_attn":
continue
Expand All @@ -593,7 +609,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
for n, p in module.named_parameters():
p.data.copy_(module_weights[n][:])
# Mark consumed weights
if can_mark_consumed:
if can_mark_consumed and not is_shared_mtp_layer:
weights.mark_consumed(name)


Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-tp4-cutl
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_1gpu[v1_kv_cache-True-True-triton-auto] SKIP (https://nvbugs/5864187)
accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-trtllm-auto] SKIP (https://nvbugs/5596343)
test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5864769)
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5879577)
unittest/_torch/thop/serial/test_moe.py::TestMoeFp4::test_gptoss_style_nvfp4[limitinf-beta0-alpha0.1-RoutingGPTOSS-512-512-1] SKIP (https://nvbugs/5819042)
unittest/_torch/flashinfer/test_trtllm_flashinfer_symbol_collision.py::test_flashinfer_fused_moe_matches_torch_moe SKIP (https://nvbugs/5920779)
test_e2e.py::test_openai_chat_guided_decoding[openai/gpt-oss-120b] SKIP (https://nvbugs/5884677)
Expand Down
Loading