diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 63d783c3332..aa53b330837 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -19,6 +19,7 @@ class ForwardOptions(TypedDict, total=False):
     freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
+    last_valid_token_pos: Optional[torch.LongTensor]
 
 
 class Attention(nn.Module, ABC):
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 1fdcdcd91fc..a53e1716375 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -204,7 +204,8 @@ def forward(
 
         if not self.generate_full_logits:
            # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            pos = attn_options.get("last_valid_token_pos", -1)
+            h = h[:, pos, :]
 
         h = self.norm(h)
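
For context, here is a minimal sketch of what the new option changes, assuming a right-padded prefill batch. The tensor sizes and the attn_options dict below are illustrative and not taken from the diff; only the `last_valid_token_pos` key and the `.get(..., -1)` fallback mirror the actual change.

import torch

# Illustrative sketch (not part of the diff): how last_valid_token_pos selects
# which logit is kept when the prompt is padded.
batch, seq_len, dim = 1, 8, 512                  # made-up sizes
h = torch.randn(batch, seq_len, dim)             # hidden states before the final norm

# Suppose only the first 5 tokens are real and positions 5..7 are padding.
attn_options = {"last_valid_token_pos": torch.tensor(4)}  # 0-dim LongTensor, per the TypedDict

pos = attn_options.get("last_valid_token_pos", -1)
h_last = h[:, pos, :]                            # shape [1, 512]; position 4, not the pad slot at -1

When the key is absent, the default of -1 preserves the previous behavior of taking the logit at the final sequence position.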