Skip to content

Commit dadadf8

Browse files
committed
Fix custom dataset after default system prompt was added
1 parent 3e39ed0 commit dadadf8

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

recipes/quickstart/finetuning/datasets/custom_dataset.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,22 @@
99

1010

1111
B_INST, E_INST = "[INST]", "[/INST]"
12+
EOT_ID = 128009 #<|eot_id|>
1213

1314
def tokenize_dialog(dialog, tokenizer):
1415
if tokenizer.vocab_size >= 128000:
1516
dialog_tokens = tokenizer.apply_chat_template(dialog)
16-
dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
17-
eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
17+
eot_indices = [i for i,n in enumerate(dialog_tokens) if n == EOT_ID]
1818
labels = copy.copy(dialog_tokens)
19+
#determine token for system and user
20+
system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
1921
last_idx = 0
2022
for n, idx in enumerate(eot_indices):
21-
if n % 2 == 1:
22-
last_idx = idx
23-
else:
23+
role_token = labels[last_idx:idx+1][2]
24+
if role_token in system_or_user:
25+
# Set labels to -100 for system and user tokens to ignore in loss function
2426
labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
27+
last_idx = idx
2528

2629
dialog_tokens = [dialog_tokens]
2730
labels_tokens = [labels]

0 commit comments

Comments
 (0)