Skip to content

Commit d45222e

Browse files
committed
Add unit test for tokenize_dialog for custom dataset
1 parent dadadf8 commit d45222e

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from transformers import AutoTokenizer
77

88
ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
9-
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B"]
9+
LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
1010

1111
@pytest.fixture(params=LLAMA_VERSIONS)
1212
def llama_version(request):

src/tests/datasets/test_custom_dataset.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"example_1": "[INST] Who made Berlin [/INST] dunno",
1212
"example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
1313
},
14-
"meta-llama/Meta-Llama-3.1-8B":{
14+
"meta-llama/Meta-Llama-3.1-8B-Instruct":{
1515
"example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
1616
"example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
1717
},
@@ -114,3 +114,35 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
114114
}
115115
with pytest.raises(AttributeError):
116116
main(**kwargs)
117+
118+
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.AutoTokenizer')
def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version):
    """Verify tokenize_dialog masks prompt/user turns with -100 and leaves
    assistant turns as trainable labels, for both Llama-2 and Llama-3.1 templates."""
    # The custom dataset module lives in the recipes tree, not the installed package,
    # so it has to be made importable first.
    monkeypatch.syspath_prepend("recipes/quickstart/finetuning/datasets/")
    from custom_dataset import tokenize_dialog

    # Wire the patched AutoTokenizer so from_pretrained() yields the fixture tokenizer.
    setup_tokenizer(tokenizer)
    tokenizer = tokenizer.from_pretrained()

    dialog = [
        {"role":"user", "content":"Who made Berlin?"},
        {"role":"assistant", "content":"dunno"},
        {"role":"user", "content":"And Rome?"},
        {"role":"assistant", "content":"Romans"},
    ]

    result = tokenize_dialog(dialog, tokenizer)

    # Debug output; pytest captures it and shows it only when an assertion fails.
    print(f"{tokenizer.encode('system')=}")
    print(f"{tokenizer.encode('user')=}")
    print(f"{tokenizer.encode('assistant')=}")
    print(f"{tokenizer.decode(result['input_ids'])=}")
    print(f"{result['labels']=}")

    labels = result["labels"]
    # The two chat templates emit different numbers of prompt tokens, so the
    # expected masked (start, stop) spans differ per tokenizer family.
    # NOTE(review): span bounds assume the exact tokenizer versions pinned in
    # conftest's LLAMA_VERSIONS — reconfirm if those change.
    if "Llama-2" in llama_version:
        masked_spans = ((0, 12), (17, 28))
    else:
        masked_spans = ((0, 35), (42, 51))

    for start, stop in masked_spans:
        assert labels[start:stop] == [-100] * (stop - start)
    # Exactly the prompt spans are masked — no stray -100 anywhere else.
    assert labels.count(-100) == sum(stop - start for start, stop in masked_spans)

0 commit comments

Comments
 (0)