Skip to content

Commit 3c9f263

Browse files
committed
Mask out assistant header
1 parent a38a80a commit 3c9f263

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

recipes/quickstart/finetuning/datasets/custom_dataset.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
B_INST, E_INST = "[INST]", "[/INST]"
1212
EOT_ID = 128009 #<|eot_id|>
1313

14+
def mask_target(target, seq):
    """Mask every occurrence of ``target`` inside ``seq`` with -100 labels.

    Used to exclude specific token spans (e.g. the assistant header tokens)
    from the loss computation, since -100 is the ignore index for
    cross-entropy loss.

    Args:
        target: List of token ids to search for.
        seq: List of label token ids; mutated in place.

    Returns:
        The (mutated) ``seq`` list, for call-site convenience.
    """
    # +1 so a match that ends exactly at the end of seq is also found;
    # the original range(len(seq) - len(target)) missed that final position.
    for i in range(len(seq) - len(target) + 1):
        if seq[i:i + len(target)] == target:
            # Overwrite the matched span so the loss function ignores it.
            seq[i:i + len(target)] = [-100] * len(target)
    return seq
19+
1420
def tokenize_dialog(dialog, tokenizer):
1521
if tokenizer.vocab_size >= 128000:
1622
dialog_tokens = tokenizer.apply_chat_template(dialog)
@@ -25,6 +31,7 @@ def tokenize_dialog(dialog, tokenizer):
2531
# Set labels to -100 for system and user tokens to ignore in loss function
2632
labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
2733
last_idx = idx
34+
mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
2835

2936
dialog_tokens = [dialog_tokens]
3037
labels_tokens = [labels]

src/tests/datasets/test_custom_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,6 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
138138
assert result["labels"][17:28] == [-100] * 11
139139
assert result["labels"].count(-100) == 11 + 12
140140
else:
141-
assert result["labels"][:35] == [-100] * 35
142-
assert result["labels"][42:51] == [-100] * 9
143-
assert result["labels"].count(-100) == 35 + 9
141+
assert result["labels"][:38] == [-100] * 38
142+
assert result["labels"][42:54] == [-100] * 12
143+
assert result["labels"].count(-100) == 38 + 12

0 commit comments

Comments
 (0)