
Commit 6273e89

Add last_token_pos in llama_transformer (pytorch#11793)
Summary: Add last_valid_token_pos to the forward options. Purpose:
* The final norm and the lm-head output can be computed on the last valid token at prefill.
* If the input sequence length is fixed because the accelerator does not support dynamic shapes, the last position of the input is not guaranteed to hold a valid token.
* An additional pointer is therefore needed to select the last valid token when computing the final norm and output.

Reviewed By: JacobSzwejbka
Differential Revision: D76440105
1 parent 851b29b commit 6273e89
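
To make the motivation concrete, here is a minimal, self-contained sketch (not part of the commit) of the padded, fixed-length prefill case: indexing the last position of the input picks a padding slot, while indexing the last valid token position recovers the hidden state that the final norm and lm-head should actually consume. The sequence length, padding layout, and tensor values below are assumptions chosen for illustration.

```python
import torch

# Hidden states after the last transformer block for one prefill call.
# The accelerator requires a fixed sequence length, so the prompt is padded.
batch, max_seq_len, dim = 1, 8, 4
h = torch.randn(batch, max_seq_len, dim)

# Assume only positions 0..4 hold real prompt tokens; positions 5..7 are padding.
last_valid_token_pos = torch.tensor(4)  # 0-dim LongTensor: index of the last real token

h_fixed_last = h[:, -1, :]                    # hidden state at a padding position
h_valid_last = h[:, last_valid_token_pos, :]  # hidden state at the last real token

# Both have shape (batch, dim); only the second is meaningful input to the lm-head.
print(h_fixed_last.shape, h_valid_last.shape)
```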

File tree: 2 files changed, +5 -1 lines changed

examples/models/llama/attention.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ class ForwardOptions(TypedDict, total=False):
     freqs_sin_override: Optional[torch.Tensor]
     in_cache_state: Optional[Any]
     out_cache_state: Optional[Any]
+    last_valid_token_pos: Optional[torch.LongTensor]


 class Attention(nn.Module, ABC):
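
ForwardOptions is a TypedDict declared with total=False, so callers only set the keys they actually use. The sketch below is a hypothetical illustration of populating the new key; it redeclares a ForwardOptions with just the fields visible in this hunk instead of importing the real class from executorch.

```python
from typing import Any, Optional, TypedDict

import torch


class ForwardOptions(TypedDict, total=False):
    # Only the fields visible in the diff hunk above are reproduced here.
    freqs_sin_override: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]
    last_valid_token_pos: Optional[torch.LongTensor]


# total=False: unset keys are simply absent, so existing callers keep working.
options: ForwardOptions = {"last_valid_token_pos": torch.tensor(4)}
print(options.get("last_valid_token_pos", None))
```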

examples/models/llama/llama_transformer.py

Lines changed: 4 additions & 1 deletion
@@ -204,7 +204,10 @@ def forward(

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            if attn_options.get("last_valid_token_pos", None):
+                h = h[:, attn_options.get("last_valid_token_pos"), :]
+            else:
+                h = h[:, -1, :]

         h = self.norm(h)
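
For reference, a standalone sketch that mirrors the selection logic above with plain tensors; `attn_options` here is an ordinary dict standing in for ForwardOptions, and the shapes are made up for illustration.

```python
import torch

h = torch.randn(1, 8, 16)  # (batch, fixed sequence length, hidden dim)
attn_options = {"last_valid_token_pos": torch.tensor(4)}

if attn_options.get("last_valid_token_pos", None):
    # Pick the hidden state at the caller-supplied last valid position.
    # Note: a 0-dim tensor is truthy only when its value is non-zero, so a
    # position of 0 would take the fallback branch below.
    h = h[:, attn_options.get("last_valid_token_pos"), :]
else:
    # Fallback: the last position, as before the change.
    h = h[:, -1, :]

print(h.shape)  # torch.Size([1, 16])
```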

0 commit comments
