Commit d17e678

Authored by Thomas Robinson (tryrobbo)

Add Llama 3.1 example upgrade script (meta-llama#5)

* Add example upgrade script
* Modify modelUpgradeExample.py as per suggestions
* Reference files in README

Co-authored-by: Thomas Robinson <[email protected]>

1 parent a1d51de commit d17e678

File tree: 2 files changed, +55 -3 lines changed


recipes/quickstart/inference/README.md

Lines changed: 4 additions & 3 deletions
```diff
@@ -2,6 +2,7 @@
 This folder contains scripts to get you started with inference on Meta Llama models.

-* [](./code_llama/) contains scripts for tasks relating to code generation using CodeLlama
-* [](./local_inference/) contsin scripts to do memory efficient inference on servers and local machines
-* [](./mobile_inference/) has scripts using MLC to serve Llama on Android (h/t to OctoAI for the contribution!)
+* [Code Llama](./code_llama/) contains scripts for tasks relating to code generation using CodeLlama
+* [Local Inference](./local_inference/) contains scripts to do memory efficient inference on servers and local machines
+* [Mobile Inference](./mobile_inference/) has scripts using MLC to serve Llama on Android (h/t to OctoAI for the contribution!)
+* [Model Update Example](./modelUpgradeExample.py) shows an example of replacing a Llama 3 model with a Llama 3.1 model.
```
recipes/quickstart/inference/modelUpgradeExample.py

Lines changed: 51 additions & 0 deletions (new file)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Running the script without any arguments ("python modelUpgradeExample.py") performs inference
# with the Llama 3 8B Instruct model.
# Passing --model-id "meta-llama/Meta-Llama-3.1-8B-Instruct" switches it to the Llama 3.1
# version of the same model.
# The script also prints the input tokens to confirm that both models respond to the same input.

import fire
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def main(model_id="meta-llama/Meta-Llama-3-8B-Instruct"):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    messages = [
        {"role": "system", "content": "You are a helpful chatbot"},
        {"role": "user", "content": "Why is the sky blue?"},
        {"role": "assistant", "content": "Because the light is scattered"},
        {"role": "user", "content": "Please tell me more about that"},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    print("Input tokens:")
    print(input_ids)

    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        max_new_tokens=400,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        attention_mask=attention_mask,
    )
    response = outputs[0][input_ids.shape[-1]:]
    print("\nOutput:\n")
    print(tokenizer.decode(response, skip_special_tokens=True))


if __name__ == "__main__":
    fire.Fire(main)
```
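The `apply_chat_template` call is what makes the 3 to 3.1 swap a one-flag change: both models ship a chat template in their tokenizer config, so the same `messages` list renders correctly for either `model_id`. As a rough illustration only, the sketch below shows a simplified, hypothetical version of what a Llama 3-style template renders (the real template lives in the tokenizer, handles more cases, and returns token ids rather than text):

```python
# Simplified, illustrative sketch of a Llama 3-style chat template.
# The real template is bundled with the tokenizer; this only shows the idea.
def render_chat(messages, add_generation_prompt=True):
    text = "<|begin_of_text|>"
    for m in messages:
        text += (
            f"<|start_header_id|>{m['role']}<|end_header_id|>\n\n"
            f"{m['content']}<|eot_id|>"
        )
    if add_generation_prompt:
        # Cue the model that the assistant speaks next.
        text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return text


messages = [
    {"role": "system", "content": "You are a helpful chatbot"},
    {"role": "user", "content": "Why is the sky blue?"},
]
print(render_chat(messages))
```

Because the template travels with the model, upgrading is a matter of pointing `model_id` at the new checkpoint rather than rewriting any prompt-formatting code.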

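One detail of the script worth noting: `model.generate` returns sequences that begin with the prompt tokens, which is why it slices with `outputs[0][input_ids.shape[-1]:]` before decoding. A minimal sketch with plain Python lists and hypothetical token ids (no model or torch required) shows the idea:

```python
# generate() output = prompt tokens followed by newly generated tokens,
# so dropping the first len(prompt_ids) entries leaves only the response.
prompt_ids = [128000, 9906, 1917]        # hypothetical prompt token ids
new_ids = [40, 2846, 264, 6369, 6465]    # hypothetical generated token ids
outputs = prompt_ids + new_ids           # shape of what generate() hands back

response = outputs[len(prompt_ids):]     # mirrors outputs[0][input_ids.shape[-1]:]
print(response)  # [40, 2846, 264, 6369, 6465]
```

Without this slice, `tokenizer.decode` would echo the entire conversation back along with the new answer.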