Commit 7bfe3b9

Fix CUDA out-of-memory issue for eager runner (pytorch#6935)

Pull Request resolved: pytorch#6866

This PR updates the eager runner to disable gradient computation and reduce memory usage. It also updates the prompt format so that it does not include the BOS token.

ghstack-source-id: 254139542
Differential Revision: [D65962743](https://our.internmc.facebook.com/intern/diff/D65962743/)
Co-authored-by: Lunwen He <[email protected]>
1 parent: 8013822
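
The memory saving comes from running eager generation under torch.no_grad(), as the eager.py diff below shows. For reference, a minimal standalone sketch of that pattern; the model and input here are placeholders, not the runner's classes:

import torch

# Placeholder model and input, standing in for the Llama runner and its tokens.
model = torch.nn.Linear(16, 16).eval()
tokens = torch.randn(1, 16)

with torch.no_grad():            # autograd stops recording operations here
    logits = model(tokens)       # no gradient buffers are kept, lowering peak memory

assert not logits.requires_grad  # outputs are detached from the autograd graph

Inside the context manager autograd records nothing, so intermediate activations that would only be needed for backpropagation are never retained.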

2 files changed: +15 −14 lines

examples/models/llama/runner/eager.py

Lines changed: 13 additions & 11 deletions

@@ -80,18 +80,20 @@ def build_args_parser() -> argparse.ArgumentParser:
 def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     parser = build_args_parser()
     args = parser.parse_args()
-    runner = runner_class(args)  # pyre-ignore: Missing argument [20]
-    generated_tokens = (
-        runner.chat_completion(temperature=args.temperature)
-        if args.chat
-        else runner.text_completion(
-            prompt=args.prompt,
-            temperature=args.temperature,
-            echo=True,
+
+    with torch.no_grad():
+        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
+        generated_tokens = (
+            runner.chat_completion(temperature=args.temperature)
+            if args.chat
+            else runner.text_completion(
+                prompt=args.prompt,
+                temperature=args.temperature,
+                echo=True,
+            )
         )
-    )
-    if args.show_tokens:
-        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
+        if args.show_tokens:
+            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 def main() -> None:

examples/models/llama/runner/generation.py

Lines changed: 2 additions & 3 deletions

@@ -199,15 +199,14 @@ def chat_completion(
                 temperature=temperature,
                 top_p=top_p,
                 echo=True,
-                pos_base=len(tokens),
+                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
             tokens.extend(new_tokens)
             prompt = input("Me: ")
         return tokens
 
     def _format_prompt(self, prompt: str) -> str:
-        return f"""
-<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
 
 You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
 
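The _format_prompt change above removes the newline that the old f-string emitted before <|begin_of_text|>, so the rendered prompt now starts directly at that marker. A small self-contained sketch of the difference follows; the tail of the template (everything after the user header) is assumed from the standard Llama 3 chat format and does not appear in this diff:

def format_prompt_old(prompt: str) -> str:
    # Old template: the f-string opens with a newline, so the rendered
    # prompt begins with "\n" rather than the <|begin_of_text|> marker.
    return f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""


def format_prompt_new(prompt: str) -> str:
    # New template: the text starts immediately at <|begin_of_text|>.
    return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""


assert format_prompt_old("Hi").startswith("\n")
assert format_prompt_new("Hi").startswith("<|begin_of_text|>")

With the leading newline gone, whatever tokenizer consumes this string no longer sees stray whitespace before the begin-of-text marker.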