We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e8cceb0 commit 7187afe (Copy full SHA for 7187afe)
tensorrt_llm/_torch/models/modeling_speculative.py
@@ -953,6 +953,14 @@ def forward(
953
hidden_states = hidden_states[:attn_metadata.num_tokens]
954
955
if self.draft_model is not None:
956
+ # For one-model speculative decoding with PP, only the last PP rank
957
+ # has valid hidden_states from the target model. The spec_worker (which
958
+ # runs the draft model loop) should only run on the last PP rank.
959
+ # Non-last PP ranks return None and let the PP sync handle the results.
960
+ mapping = self.model.model_config.mapping
961
+ if mapping.has_pp() and not mapping.is_last_pp_rank():
962
+ return None
963
+
964
# get logits
965
logits = self.logits_processor.forward(
966
hidden_states[spec_metadata.gather_ids],
0 commit comments