Skip to content

Commit d9939cd

Browse files
committed
Fix to take the BOS token into account
1 parent 3c9f263 commit d9939cd

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

recipes/quickstart/finetuning/datasets/custom_dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@ def tokenize_dialog(dialog, tokenizer):
2424
labels = copy.copy(dialog_tokens)
2525
#determine token for system and user
2626
system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
27-
last_idx = 0
27+
labels[0] = -100 # bos token
28+
last_idx = 1
2829
for n, idx in enumerate(eot_indices):
29-
role_token = labels[last_idx:idx+1][2]
30+
role_token = labels[last_idx+1]
3031
if role_token in system_or_user:
3132
# Set labels to -100 for system and user tokens to ignore in loss function
3233
labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
33-
last_idx = idx
34+
last_idx = idx + 1
3435
mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
3536

3637
dialog_tokens = [dialog_tokens]

src/tests/datasets/test_custom_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,5 @@ def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version)
139139
assert result["labels"].count(-100) == 11 + 12
140140
else:
141141
assert result["labels"][:38] == [-100] * 38
142-
assert result["labels"][42:54] == [-100] * 12
143-
assert result["labels"].count(-100) == 38 + 12
142+
assert result["labels"][43:54] == [-100] * 11
143+
assert result["labels"].count(-100) == 38 + 11

0 commit comments

Comments
 (0)