
Commit 64d3437

Update eager runner to support AttentionSink
This PR updates the eager runner to support AttentionSink. It also fixes issues in the `chat_completion` function so that position ids are handled correctly across chat turns.

Differential Revision: [D66076486](https://our.internmc.facebook.com/intern/diff/D66076486/)

[ghstack-poisoned]
1 parent e6d16be · commit 64d3437

2 files changed (+23 −12 lines)

examples/models/llama/runner/eager.py

Lines changed: 5 additions & 1 deletion
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     with torch.no_grad():
         runner = runner_class(args)  # pyre-ignore: Missing argument [20]
         generated_tokens = (
-            runner.chat_completion(temperature=args.temperature)
+            runner.chat_completion(
+                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
+                temperature=args.temperature,
+                show_progress=args.show_tokens,
+            )
             if args.chat
             else runner.text_completion(
                 prompt=args.prompt,
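The interesting part of this call site is how the generation cap is chosen. Below is a small, self-contained sketch of that selection; the helper name `resolve_max_seq_len` and the `SimpleNamespace` stand-in for parsed CLI arguments are illustrative only, since the commit inlines the expression directly at the call site.

```python
from types import SimpleNamespace


def resolve_max_seq_len(args) -> int:
    """Pick the cap the eager runner passes to chat_completion."""
    # With AttentionSink enabled, generation is not limited to the model's
    # fixed context window, so the runner uses an effectively unbounded cap
    # (1,000,000 tokens) instead of the model's max_seq_length.
    return 1000000 if args.use_attention_sink else args.max_seq_length


# Example values only; the real values come from the runner's argument parser.
print(resolve_max_seq_len(SimpleNamespace(use_attention_sink=True, max_seq_length=2048)))   # 1000000
print(resolve_max_seq_len(SimpleNamespace(use_attention_sink=False, max_seq_length=2048)))  # 2048
```

`show_progress=args.show_tokens` similarly just forwards a CLI flag so the chat loop can print a running token count after each turn.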

examples/models/llama/runner/generation.py

Lines changed: 18 additions & 11 deletions
@@ -168,18 +168,19 @@ def text_completion(
 
     def chat_completion(
         self,
+        max_seq_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
+        show_progress: bool = False,
     ) -> List[int]:
         """
         Perform multi-turn chat with the language model.
 
         Args:
-            prompt (str): Text prompt for completion.
+            max_seq_len (int): Maximum number of tokens to generate for each prompt.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
+            show_progress (bool, optional): Flag indicating whether to show number of tokens generated.
         Returns:
             Generated list of tokens.
 
@@ -188,20 +189,26 @@ def chat_completion(
         """
         exit_prompt = "exit"
         tokens = []
+        pre_stop_token = []
         prompt = input("Me: ")
         while prompt and prompt != exit_prompt:
             print("LLM: ", end="", flush=True)
-            new_tokens = self.generate(
-                prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
-                ),
-                max_seq_len=self.max_seq_len,
+            prompt_tokens = self.tokenizer.encode(
+                self._format_prompt(prompt), bos=True, eos=False
+            )
+            generated_tokens = self.generate(
+                prompt_tokens=pre_stop_token + prompt_tokens,
+                max_seq_len=max_seq_len,
                 temperature=temperature,
                 top_p=top_p,
-                echo=True,
-                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0
+                echo=False,
+                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
-            tokens.extend(new_tokens)
+            pre_stop_token = generated_tokens[-1:]
+            tokens.extend(prompt_tokens)
+            tokens.extend(generated_tokens)
+            if show_progress:
+                print(f"[Generated {len(tokens)} tokens]")
             prompt = input("Me: ")
         return tokens
 
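To make the position-id fix concrete, here is a minimal, runnable sketch of the bookkeeping the updated `chat_completion` performs across turns: the last generated token is carried into the next turn's input as `pre_stop_token`, both prompt and generated tokens are accumulated, and `pos_base` is derived from that running total so position ids stay contiguous. The tokenizer and `generate` below are stand-ins for illustration, not the runner's real implementations.

```python
from typing import List


def fake_encode(prompt: str) -> List[int]:
    # Stand-in for self.tokenizer.encode(self._format_prompt(prompt), bos=True, eos=False).
    return list(range(len(prompt.split())))


def fake_generate(prompt_tokens: List[int], max_seq_len: int, pos_base: int) -> List[int]:
    # Stand-in for self.generate(); it only reports pos_base and returns dummy
    # tokens, with 2 playing the role of the stop token.
    print(f"  generate(pos_base={pos_base}, input_len={len(prompt_tokens)})")
    return [101, 102, 2]


def chat_loop(user_turns: List[str], max_seq_len: int = 128) -> List[int]:
    tokens: List[int] = []
    pre_stop_token: List[int] = []  # last token generated in the previous turn, if any
    for prompt in user_turns:
        prompt_tokens = fake_encode(prompt)
        generated_tokens = fake_generate(
            # Re-feed the previous stop token ahead of the new prompt.
            prompt_tokens=pre_stop_token + prompt_tokens,
            max_seq_len=max_seq_len,
            # Continue numbering from everything processed so far; the -1 lets the
            # carried stop token keep the position it was generated at.
            pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
        )
        pre_stop_token = generated_tokens[-1:]
        tokens.extend(prompt_tokens)
        tokens.extend(generated_tokens)
    return tokens


chat_loop(["hello there", "tell me more please"])
```

Running it prints `pos_base=0` for the first turn and `pos_base=4` for the second (5 tokens accumulated, minus one because the carried stop token re-uses its original position).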
