Commit 042aac7

Fix llama 3 (#1671)
* fix llama 3
* fix black
1 parent 3ba798e commit 042aac7

File tree

4 files changed: +139 -1 lines changed


examples/llama3/README.md

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Chat with Llama 3

In addition to the Llama 2 example, this is an example of how to implement an interactive chat session with Llama 3. The script updates the prompt template to the format used by Llama 3.

## Installation

You can follow the README in the llama2 example to set up the environment.

## Start a chat session

```
python3 chat.py llama-3-7b-chat-ct2/
```

You can also set a system prompt on the command line:

```
python3 chat.py llama-3-7b-chat-ct2/ "System prompt..."
```
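The llama2 README also covers converting the checkpoint to the CTranslate2 format. For reference, a minimal conversion sketch using the Python API could look like the following; the output directory name and the int8 quantization are illustrative choices, not requirements of this example.

```
import ctranslate2

# Convert the Hugging Face checkpoint to the CTranslate2 format.
# The directory name and quantization mode below are only examples.
converter = ctranslate2.converters.TransformersConverter(
    "meta-llama/Meta-Llama-3-8B-Instruct"
)
converter.convert("llama-3-7b-chat-ct2", quantization="int8")
```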

examples/llama3/chat.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import os
import sys

import ctranslate2
from transformers import AutoTokenizer


def main():
    model_dir = sys.argv[1]
    system_prompt = sys.argv[2] if len(sys.argv) > 2 else None

    print("Loading the model...")
    generator = ctranslate2.Generator(model_dir, device="cuda")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    context_length = 4096
    max_generation_length = 512
    max_prompt_length = context_length - max_generation_length

    dialog = []

    if system_prompt:
        dialog.append({"role": "system", "content": system_prompt})

    while True:
        print("")

        user_prompt = input("You: ")

        dialog.append({"role": "user", "content": user_prompt})

        # Trim the oldest turns until the prompt fits in the context window.
        while True:
            prompt_tokens = build_prompt(tokenizer, dialog)
            if len(prompt_tokens) <= max_prompt_length:
                break
            # Remove old conversations to reduce the prompt size.
            if system_prompt:
                dialog = [dialog[0]] + dialog[3:]
            else:
                dialog = dialog[2:]

        step_results = generator.generate_tokens(
            prompt_tokens,
            max_length=max_generation_length,
            sampling_temperature=0.6,
            sampling_topk=20,
            sampling_topp=1,
        )

        print("")
        print("Llama3: ", end="", flush=True)

        text_output = ""

        for word in generate_words(tokenizer, step_results):
            print(word, end="", flush=True)
            text_output += word

        print("")

        dialog.append({"role": "assistant", "content": text_output.strip()})


def generate_words(tokenizer, step_results):
    tokens_buffer = []

    for step_result in step_results:
        # "Ġ" marks the start of a new word in the BPE vocabulary.
        is_new_word = step_result.token.startswith("Ġ")

        if is_new_word and tokens_buffer:
            word = tokenizer.decode(tokens_buffer)
            if word:
                yield word
            tokens_buffer = []

        tokens_buffer.append(step_result.token_id)

    if tokens_buffer:
        word = tokenizer.decode(tokens_buffer)
        if word:
            yield word


B_ID, E_ID, E_INST = "<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"


def build_prompt(tokenizer, dialog):
    # Wrap each message in the Llama 3 header tokens, end each turn with
    # <|eot_id|>, and append the assistant header so the model replies next.
    begin_pos = 0
    if dialog[0]["role"] == "system":
        begin_pos = 1
    assert all([msg["role"] == "user" for msg in dialog[begin_pos::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[begin_pos + 1 :: 2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )

    dialog_tokens = sum(
        [
            tokenizer.tokenize(
                f"{B_ID} {(item['role'])} {E_ID} {(item['content']).strip()} {E_INST}"
            )
            for item in dialog
        ],
        [],
    )
    dialog_tokens = ["<|begin_of_text|>"] + dialog_tokens + tokenizer.tokenize(
        f"{B_ID} assistant {E_ID}"
    )

    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"

    return dialog_tokens


if __name__ == "__main__":
    main()
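For reference, the tokenizers shipped with transformers 4.40 also include Llama 3's chat template, so a comparable prompt can be rendered with apply_chat_template. The sketch below is only illustrative and is not what chat.py does; note that build_prompt above inserts spaces around the header tokens, so the two renderings are not byte-identical.

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Render the dialog with the tokenizer's built-in chat template and append
# the assistant header so the model would start generating its reply.
prompt = tokenizer.apply_chat_template(
    dialog, tokenize=False, add_generation_prompt=True
)
print(prompt)
```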

examples/llama3/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
ctranslate2>=4.3.0
transformers[torch]==4.40.*
accelerate

python/ctranslate2/converters/transformers.py

Lines changed: 3 additions & 1 deletion
@@ -1436,7 +1436,9 @@ def set_vocabulary(self, spec, tokens):
     def set_config(self, config, model, tokenizer):
         config.bos_token = tokenizer.bos_token
         config.eos_token = tokenizer.eos_token
-        config.unk_token = tokenizer.unk_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
         config.layer_norm_epsilon = model.config.rms_norm_eps

     def set_layer_norm(self, spec, layer_norm):
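The fallback is needed because the Llama 3 tokenizer defines no unknown token, so tokenizer.unk_token is None and would previously have been written into the converted model's configuration as-is. A quick, illustrative way to check this (not part of the commit):

```
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Llama 3 ships without an unknown token, so this prints None,
# which is why the converter now falls back to an empty string.
print(tokenizer.unk_token)
```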
