|
11 | 11 | "example_1": "[INST] Who made Berlin [/INST] dunno",
|
12 | 12 | "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
|
13 | 13 | },
|
14 |
| - "meta-llama/Meta-Llama-3.1-8B":{ |
| 14 | + "meta-llama/Meta-Llama-3.1-8B-Instruct":{ |
15 | 15 | "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
|
16 | 16 | "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
|
17 | 17 | },
|
@@ -114,3 +114,35 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
|
114 | 114 | }
|
115 | 115 | with pytest.raises(AttributeError):
|
116 | 116 | main(**kwargs)
|
| 117 | + |
| 118 | +@pytest.mark.skip_missing_tokenizer |
| 119 | +@patch('llama_recipes.finetuning.AutoTokenizer') |
| 120 | +def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version): |
| 121 | + monkeypatch.syspath_prepend("recipes/quickstart/finetuning/datasets/") |
| 122 | + from custom_dataset import tokenize_dialog |
| 123 | + |
| 124 | + setup_tokenizer(tokenizer) |
| 125 | + tokenizer = tokenizer.from_pretrained() |
| 126 | + |
| 127 | + dialog = [ |
| 128 | + {"role":"user", "content":"Who made Berlin?"}, |
| 129 | + {"role":"assistant", "content":"dunno"}, |
| 130 | + {"role":"user", "content":"And Rome?"}, |
| 131 | + {"role":"assistant", "content":"Romans"}, |
| 132 | + ] |
| 133 | + |
| 134 | + result = tokenize_dialog(dialog, tokenizer) |
| 135 | + print(f"{tokenizer.encode('system')=}") |
| 136 | + print(f"{tokenizer.encode('user')=}") |
| 137 | + print(f"{tokenizer.encode('assistant')=}") |
| 138 | + print(f"{tokenizer.decode(result['input_ids'])=}") |
| 139 | + print(f"{result['labels']=}") |
| 140 | + |
| 141 | + if "Llama-2" in llama_version: |
| 142 | + assert result["labels"][:12] == [-100] * 12 |
| 143 | + assert result["labels"][17:28] == [-100] * 11 |
| 144 | + assert result["labels"].count(-100) == 11 + 12 |
| 145 | + else: |
| 146 | + assert result["labels"][:35] == [-100] * 35 |
| 147 | + assert result["labels"][42:51] == [-100] * 9 |
| 148 | + assert result["labels"].count(-100) == 35 + 9 |
0 commit comments