Commit 1c7b7cd
[TRTLLM-9506][fix] Fix AR for DeepSeek-R1 2 model path (#9661)
Signed-off-by: qgai <qgai@nvidia.com>
Parent: 98db262

2 files changed: +12 -4

tensorrt_llm/_torch/models/modeling_deepseekv3.py (9 additions, 4 deletions)
@@ -1233,13 +1233,13 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
             hidden_states, residual = self.moe_allreduce(
                 fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
-            if spec_metadata is not None and spec_metadata.is_layer_capture(
-                    self.layer_idx):
-                spec_metadata.maybe_capture_hidden_states(
-                    self.layer_idx, hidden_states, residual)
             if self.next_layer_layernorm is not None:
                 hidden_states, residual = self.next_layer_layernorm(
                     hidden_states, residual)
+            if spec_metadata is not None and spec_metadata.is_layer_capture(
+                    self.layer_idx):
+                spec_metadata.maybe_capture_hidden_states(
+                    self.layer_idx, hidden_states, None)
 
         return hidden_states, residual
 
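
The substance of this hunk: the speculative hidden-state capture moves from before next_layer_layernorm to after it, and the residual argument passed to maybe_capture_hidden_states changes from residual to None, so what gets captured is the post-norm tensor. A minimal runnable sketch of the fixed ordering; SpecMetadataStub and fused_add_rmsnorm below are illustrative stand-ins, not the real tensorrt_llm classes or API.

    import torch

    class SpecMetadataStub:
        # Illustration only; stands in for tensorrt_llm's SpecMetadata.
        def __init__(self, layers_to_capture):
            self.layers_to_capture = set(layers_to_capture)
            self.captured = {}

        def is_layer_capture(self, layer_idx):
            return layer_idx in self.layers_to_capture

        def maybe_capture_hidden_states(self, layer_idx, hidden_states, residual):
            # residual is None on the fixed path: the post-norm tensor is
            # stored as-is.
            self.captured[layer_idx] = hidden_states

    def fused_add_rmsnorm(hidden_states, residual, weight, eps=1e-6):
        # Mirrors the (hidden_states, residual) -> (hidden_states, residual)
        # shape of the fused norm in the diff: add residual, then RMS-normalize.
        residual = hidden_states + residual
        variance = residual.pow(2).mean(-1, keepdim=True)
        return residual * torch.rsqrt(variance + eps) * weight, residual

    def finish_layer(layer_idx, next_layer_layernorm, hidden_states, residual,
                     spec_metadata=None):
        # Fixed ordering from the hunk above: normalize first, capture second.
        if next_layer_layernorm is not None:
            hidden_states, residual = next_layer_layernorm(hidden_states, residual)
        if spec_metadata is not None and spec_metadata.is_layer_capture(layer_idx):
            spec_metadata.maybe_capture_hidden_states(layer_idx, hidden_states, None)
        return hidden_states, residual

    spec = SpecMetadataStub(layers_to_capture={5})
    weight = torch.ones(8)
    norm = lambda h, r: fused_add_rmsnorm(h, r, weight)
    h, r = finish_layer(5, norm, torch.randn(2, 8), torch.randn(2, 8), spec)
    assert torch.equal(spec.captured[5], h)  # the post-norm tensor was captured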

@@ -1357,6 +1357,7 @@ def forward(
         embed_tokens: Embedding,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> torch.Tensor:
 

@@ -1433,6 +1434,10 @@ def norm_hidden():
         else:
             hidden_states, _ = self.shared_head.norm(hidden_states, residual)
 
+        # It's for 2-model path, capture the hidden states
+        if spec_metadata is not None:
+            spec_metadata.maybe_capture_hidden_states(0, hidden_states, None)
+
         return hidden_states
 
 
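This hunk is the 2-model half of the fix: when the MTP/draft head runs with a spec_metadata object present, it now records its post-norm hidden states under layer index 0. A short sketch reusing SpecMetadataStub from above; shared_head_norm is a stand-in name for the self.shared_head.norm call in the diff.

    def mtp_head_forward(hidden_states, residual, shared_head_norm,
                         spec_metadata=None):
        # Mirrors the added lines: normalize, then capture into slot 0
        # whenever spec_metadata is supplied (the 2-model path supplies one).
        hidden_states, _ = shared_head_norm(hidden_states, residual)
        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(0, hidden_states, None)
        return hidden_states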

tensorrt_llm/_torch/models/modeling_speculative.py (3 additions, 0 deletions)
@@ -455,6 +455,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
+        spec_metadata: Optional[SpecMetadata] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = self.layers(
@@ -464,6 +465,7 @@ def forward(
             embed_tokens=self.embed_tokens,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=all_rank_num_tokens,
+            spec_metadata=spec_metadata,
         )
 
         return hidden_states
@@ -518,6 +520,7 @@ def forward(self,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            spec_metadata=spec_metadata,
             **kwargs)
         return self.logits_processor.forward(
             output,
