Skip to content

Commit 7187afe

Browse files
authored
[https://nvbugs/5781589][fix] Skip spec dec for non-last rank (#10445)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
1 parent e8cceb0 commit 7187afe

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,14 @@ def forward(
953953
hidden_states = hidden_states[:attn_metadata.num_tokens]
954954

955955
if self.draft_model is not None:
956+
# For one-model speculative decoding with PP, only the last PP rank
957+
# has valid hidden_states from the target model. The spec_worker (which
958+
# runs the draft model loop) should only run on the last PP rank.
959+
# Non-last PP ranks return None and let the PP sync handle the results.
960+
mapping = self.model.model_config.mapping
961+
if mapping.has_pp() and not mapping.is_last_pp_rank():
962+
return None
963+
956964
# get logits
957965
logits = self.logits_processor.forward(
958966
hidden_states[spec_metadata.gather_ids],

0 commit comments

Comments
 (0)