Commit 63b070d

Update on "Print the number of tokens generated"
This is useful for verifying the correctness of AttentionSink.

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

[ghstack-poisoned]
1 parent: 82f8713

1 file changed (+4, -1 lines)


examples/models/llama/runner/generation.py

Lines changed: 4 additions & 1 deletion
@@ -64,6 +64,7 @@ def forward(
     def generate(  # noqa: C901
         self,
         prompt_tokens: List[int],
+        max_seq_len: int,
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
@@ -83,7 +84,7 @@ def generate( # noqa: C901
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

-        while len(tokens) < self.params.max_seq_len:
+        while len(tokens) < max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
@@ -135,6 +136,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -169,6 +171,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
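
For context, a minimal usage sketch of the updated signature: generate() now takes max_seq_len as an explicit argument, while text_completion and chat_completion keep the previous behavior by forwarding self.params.max_seq_len. In the sketch below, the runner object, the prompt text, the 128-token cap, and the helper function name are illustrative assumptions, not part of this diff; only tokenizer.encode and the generate() keyword arguments are taken from the code above.

# Sketch only: assumes `runner` is an already-constructed instance of the runner
# class defined in examples/models/llama/runner/generation.py.
from typing import List


def run_with_explicit_cap(runner, prompt: str, cap: int = 128) -> List[int]:
    prompt_tokens = runner.tokenizer.encode(prompt, bos=True, eos=False)
    # generate() now takes the sequence-length cap explicitly instead of
    # always reading self.params.max_seq_len.
    tokens = runner.generate(
        prompt_tokens=prompt_tokens,
        max_seq_len=cap,  # caller-chosen cap; may differ from params.max_seq_len
        temperature=0.8,
        top_p=0.9,
        echo=False,
    )
    # Handy for AttentionSink checks: report how many tokens were produced beyond
    # the prompt (assumes generate() returns prompt + generated tokens, as
    # suggested by `tokens = prompt_tokens + [current_token]` in the diff).
    print(f"Generated {len(tokens) - len(prompt_tokens)} tokens")
    return tokens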
