
Commit 9f52006

Fix tests for custom dataset, grammar, batching, chat_completion
1 parent 8b01298 commit 9f52006

File tree

4 files changed: +23 −77 lines changed

src/tests/datasets/test_custom_dataset.py

Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ def check_padded_entry(batch, tokenizer):
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):

@@ -97,7 +97,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
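
Side note (not part of the commit): unittest.mock applies stacked @patch decorators bottom-up, so the mocks reach the test in reverse decorator order; that is why the innermost @patch('llama_recipes.finetuning.StepLR') feeds the first parameter, step_lr, and the swapped-in LlamaForCausalLM patch still lands in get_model. A minimal sketch of that mapping, with a hypothetical test name and assuming llama_recipes is importable:

# Hedged sketch: decorator-to-argument mapping for stacked @patch decorators.
from unittest.mock import patch

@patch("llama_recipes.finetuning.train")                             # -> train (last arg)
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")  # -> get_model
@patch("llama_recipes.finetuning.optim.AdamW")                       # -> optimizer
@patch("llama_recipes.finetuning.StepLR")                            # -> step_lr (first arg)
def test_patch_order(step_lr, optimizer, get_model, train):
    # The bottom-most decorator supplies the first mock argument; the mocked
    # from_pretrained arrives as get_model, where the test can stub attributes
    # (e.g. an embedding shape) before exercising the finetuning code.
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]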

src/tests/datasets/test_grammar_datasets.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):

src/tests/test_batching.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
         "train": 96,
         "eval": 42,
     },
-    "meta-llama/Meta-Llama-3.1-8B": {
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
         "train": 79,
         "eval": 34,
     }

src/tests/test_chat_completion.py

Lines changed: 19 additions & 73 deletions

@@ -1,113 +1,64 @@
 import sys
 from pathlib import Path
-from typing import List, Literal, TypedDict
+from typing import List, TypedDict
 from unittest.mock import patch
 
 import pytest
 import torch
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 
 ROOT_DIR = Path(__file__).parents[2]
-CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/quickstart/inference/local_inference/chat_completion/"
 
 sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
 
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
+default_system_prompt = [{"role": "system", "content": "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"}]
 
 def _encode_header(message, tokenizer):
     tokens = []
-    tokens.extend(tokenizer.encode("<|start_header_id|>"))
-    tokens.extend(tokenizer.encode(message["role"]))
-    tokens.extend(tokenizer.encode("<|end_header_id|>"))
-    tokens.extend(tokenizer.encode("\n\n"))
+    tokens.extend(tokenizer.encode("<|start_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|end_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
     return tokens
 
 
 def _encode_message(message, tokenizer):
     tokens = _encode_header(message, tokenizer)
-    tokens.extend(tokenizer.encode(message["content"].strip()))
-    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    tokens.extend(tokenizer.encode(message["content"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|eot_id|>", add_special_tokens=False))
     return tokens
 
 
 def _format_dialog(dialog, tokenizer):
     tokens = []
-    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    tokens.extend(tokenizer.encode("<|begin_of_text|>", add_special_tokens=False))
+    if dialog[0]["role"] == "system":
+        dialog[0]["content"] = default_system_prompt[0]["content"] + dialog[0]["content"]
+    else:
+        dialog = default_system_prompt + dialog
     for msg in dialog:
         tokens.extend(_encode_message(msg, tokenizer))
-    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
     return tokens
 
 
 def _format_tokens_llama3(dialogs, tokenizer):
     return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
 
 
-def _format_tokens_llama2(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-                {
-                    "role": dialog[1]["role"],
-                    "content": B_SYS
-                    + dialog[0]["content"]
-                    + E_SYS
-                    + dialog[1]["content"],
-                }
-            ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                )
-                + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-
-
 @pytest.mark.skip_missing_tokenizer
 @patch("chat_completion.AutoTokenizer")
 @patch("chat_completion.load_model")
 def test_chat_completion(
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
 ):
+    if "Llama-2" in llama_version:
+        pytest.skip("skipping test for Llama-2")
+
     from chat_completion import main
 
     setup_tokenizer(tokenizer)
-    load_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    load_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]
 
     kwargs = {
         "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),

@@ -116,13 +67,8 @@ def test_chat_completion(
     main(llama_version, **kwargs)
 
     dialogs = read_dialogs_from_file(kwargs["prompt_file"])
-    format_tokens = (
-        _format_tokens_llama2
-        if llama_version == "meta-llama/Llama-2-7b-hf"
-        else _format_tokens_llama3
-    )
 
-    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+    REF_RESULT = _format_tokens_llama3(dialogs, llama_tokenizer[llama_version])
 
     assert all(
         (
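For reference (not part of the commit): the retained helpers build the Llama 3 chat layout at the token level, with the default system prompt merged into or prepended to the dialog. A rough plain-string sketch of that same layout, using a made-up dialog, just to show what _encode_header/_encode_message/_format_dialog assemble:

# Hedged sketch: the Llama 3 chat layout rendered as a plain string for readability.
DEFAULT_SYSTEM = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"

def render_dialog(dialog):
    # Merge the default system prompt into an existing system turn, or prepend one.
    if dialog and dialog[0]["role"] == "system":
        dialog = [{"role": "system", "content": DEFAULT_SYSTEM + dialog[0]["content"]}] + dialog[1:]
    else:
        dialog = [{"role": "system", "content": DEFAULT_SYSTEM}] + list(dialog)
    text = "<|begin_of_text|>"
    for msg in dialog:
        text += f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
    return text

# Hypothetical dialog, only to show the resulting layout.
print(render_dialog([{"role": "user", "content": "What is the capital of France?"}]))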