diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index 63d783c3332..aa53b330837 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -19,6 +19,7 @@ class ForwardOptions(TypedDict, total=False):
     freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
+    last_valid_token_pos: Optional[torch.LongTensor]
 
 
 class Attention(nn.Module, ABC):
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 1fdcdcd91fc..1eb28c1ad5f 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -204,7 +204,10 @@ def forward(
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            if attn_options.get("last_valid_token_pos", None) is not None:
+                h = h[:, attn_options.get("last_valid_token_pos"), :]
+            else:
+                h = h[:, -1, :]
 
         h = self.norm(h)
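Below is a minimal, hypothetical sketch (not part of the patch) of how the new last_valid_token_pos entry in ForwardOptions changes which position's hidden state is kept when generate_full_logits is disabled, for example when a prompt is padded past its last real token. The tensor shapes and the standalone attn_options dict are assumptions for illustration only; the real selection happens inside the model's forward as shown in the diff above.

import torch

batch, seq_len, dim = 2, 8, 16
h = torch.randn(batch, seq_len, dim)  # hidden states from the transformer blocks

# Hypothetical options mapping mirroring ForwardOptions; only the new key is shown.
# A 0-d LongTensor with value 5 marks the last real (non-padding) token position.
attn_options = {"last_valid_token_pos": torch.tensor(5)}

pos = attn_options.get("last_valid_token_pos", None)
if pos is not None:
    h = h[:, pos, :]   # keep the hidden state at the last valid token
else:
    h = h[:, -1, :]    # default: final position of the sequence

print(h.shape)  # torch.Size([2, 16]) -- one hidden vector per batch element

Checking the option with "is not None" rather than plain truthiness keeps position 0 usable, since a zero-valued tensor evaluates as falsy and would otherwise silently fall back to the final position.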