From ce3d29bdbfb7a99f654f1cbbe13e1680402d68e2 Mon Sep 17 00:00:00 2001
From: Jinook Song
Date: Mon, 7 Jul 2025 16:13:57 -0700
Subject: [PATCH] Add last_token_pos in llama_transformer (#12239)

Summary:
Add last_valid_token_pos to the forward options (ForwardOptions).

Purpose:
* The final norm and the lm-head output can be computed on the last valid token during prefill.
* When the input sequence length is fixed because an accelerator does not support dynamic shapes, the last token of the padded input is not guaranteed to be a valid token.
* An additional pointer is therefore needed so that only the last valid token is used for the final norm and output.

Reviewed By: JacobSzwejbka

Differential Revision: D76440105
---
 examples/models/llama/attention.py         | 1 +
 examples/models/llama/llama_transformer.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 63d783c3332..aa53b330837 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -19,6 +19,7 @@ class ForwardOptions(TypedDict, total=False):
     freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
+    last_valid_token_pos: Optional[torch.LongTensor]


 class Attention(nn.Module, ABC):
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 1fdcdcd91fc..a53e1716375 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -204,7 +204,8 @@ def forward(

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            pos = attn_options.get("last_valid_token_pos", -1)
+            h = h[:, pos, :]

         h = self.norm(h)
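
Usage sketch (illustrative, not part of the patch): a minimal, self-contained
example of the selection logic this change enables. The tensor shapes,
num_valid, and the attn_options dict below are made-up assumptions; only the
two lines marked as mirroring the patch correspond to the actual code in
llama_transformer.py.

import torch

# Assumed setup: a fixed-length graph (no dynamic shapes) with a short
# prompt right-padded to max_seq_len. All values here are illustrative.
batch, max_seq_len, dim = 1, 8, 16
h = torch.randn(batch, max_seq_len, dim)  # hidden states after the last layer

num_valid = 3  # prompt occupies positions 0..2; positions 3..7 are padding

attn_options = {
    "last_valid_token_pos": torch.tensor([num_valid - 1], dtype=torch.long),
}

# Mirrors the patched lines in llama_transformer.py: fall back to -1 (the
# last physical position) when the option is absent.
pos = attn_options.get("last_valid_token_pos", -1)
h_last = h[:, pos, :]  # selects valid position 2 instead of padding position 7

print(h_last.shape)  # torch.Size([1, 1, 16]) when pos is a LongTensor

Note that indexing with a LongTensor keeps the sequence dimension (shape
[B, 1, D]), whereas the -1 integer default drops it ([B, D]); callers
presumably account for this difference downstream.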