
Commit 8a371a7

Fix CUDA out of memory issue for eager runner

This PR updates the eager runner to disable gradient tracking and reduce memory usage. It also updates the prompt format to not include the BOS token. Differential Revision: [D65962743](https://our.internmc.facebook.com/intern/diff/D65962743/) [ghstack-poisoned]
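For reference, the memory saving comes from running the whole generation path under torch.no_grad(), so PyTorch never records an autograd graph during inference. A minimal sketch of the pattern, with a placeholder model standing in for the actual EagerLlamaRunner:

import torch
import torch.nn as nn

# Placeholder model and input; the real runner loads a Llama checkpoint instead.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4096, 4096).to(device)
x = torch.randn(1, 4096, device=device)

with torch.no_grad():      # no grad_fn is recorded for any op inside this block
    for _ in range(128):   # stand-in for the token-by-token generation loop
        x = model(x)       # activations are not retained for a backward pass,
                           # so GPU memory stays flat across iterations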
1 parent ecdc007 commit 8a371a7

2 files changed: +15 −15 lines


examples/models/llama/runner/eager.py

Lines changed: 12 additions & 11 deletions
@@ -84,18 +84,19 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
 
-    runner = EagerLlamaRunner(args)
-    generated_tokens = (
-        runner.chat_completion(temperature=args.temperature)
-        if args.chat
-        else runner.text_completion(
-            prompt=args.prompt,
-            temperature=args.temperature,
-            echo=True,
+    with torch.no_grad():
+        runner = EagerLlamaRunner(args)
+        generated_tokens = (
+            runner.chat_completion(temperature=args.temperature)
+            if args.chat
+            else runner.text_completion(
+                prompt=args.prompt,
+                temperature=args.temperature,
+                echo=True,
+            )
         )
-    )
-    if args.show_tokens:
-        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
+        if args.show_tokens:
+            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":

examples/models/llama/runner/generation.py

Lines changed: 3 additions & 4 deletions
@@ -135,7 +135,7 @@ def text_completion(
         This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
         return self.generate(
-            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            prompt_tokens=self.tokenizer.encode(prompt, bos=False, eos=False),
             max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
@@ -169,7 +169,7 @@ def chat_completion(
             print("LLM: ", end="", flush=True)
             new_tokens = self.generate(
                 prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
+                    self._format_prompt(prompt), bos=False, eos=False
                 ),
                 max_seq_len=self.params.max_seq_len,
                 temperature=temperature,
@@ -182,8 +182,7 @@ def chat_completion(
         return tokens
 
     def _format_prompt(self, prompt: str) -> str:
-        return f"""
-<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
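The chat template returned by _format_prompt already starts with <|begin_of_text|>, so encoding it with bos=True prepended a second BOS token (and the old f-string also began with a stray newline); with bos=False the template is the single source of the BOS. A toy sketch of the duplication this avoids, using a hypothetical stand-in for the real tokenizer (128000 is the usual Llama 3 id for <|begin_of_text|>):

# Hypothetical minimal tokenizer, for illustration only.
BOS_ID = 128000

def encode(text: str, bos: bool) -> list[int]:
    # Pretend the formatted template already encodes to [BOS_ID, ...].
    body = [BOS_ID, 9906, 11, 1917]  # dummy ids following the leading BOS
    return ([BOS_ID] + body) if bos else body

formatted = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>..."
assert encode(formatted, bos=True)[:2] == [BOS_ID, BOS_ID]  # duplicated BOS
assert encode(formatted, bos=False)[0] == BOS_ID            # single BOS from the template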
