Commit 8394927

raft_dataset.py must be used with llama3 tokenizer
1 parent 890d49d commit 8394927

File tree

1 file changed (+2, -16 lines)

recipes/finetuning/datasets/raft_dataset.py

Lines changed: 2 additions & 16 deletions
@@ -3,8 +3,7 @@
 
 
 import copy
-import datasets
-from datasets import Dataset, load_dataset, DatasetDict
+from datasets import load_dataset
 import itertools
 
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -26,8 +25,6 @@ def tokenize_dialog(dialog, tokenizer):
         eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
         labels = copy.copy(dialog_tokens)
         last_idx = 0
-        token_length = len(dialog_tokens)
-        last_idx = 0
         # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
         # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
         prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
@@ -44,18 +41,7 @@ def tokenize_dialog(dialog, tokenizer):
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
     else:
-        # Otherwise, use the original tokenizer to generate the tokens as it is from Llama 2 family models
-        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[:2]]
-        answer = dialog[-1]
-        answer_tokens = tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False)
-
-        #Add labels, convert prompt token to -100 in order to ignore in loss function
-        sample = {
-            "input_ids": prompt_tokens + answer_tokens,
-            "attention_mask" : [1] * (len(prompt_tokens) + len(answer_tokens)),
-            "labels": [-100] * len(prompt_tokens) + answer_tokens,
-            }
-        return sample
+        raise Exception("This raft_dataset only supports Llama 3 family models, please make sure the tokenizer is from Llama 3 family models.")
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

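Both the removed Llama 2 branch and the surviving Llama 3 path mask prompt tokens by setting their labels to -100. A short illustration (an assumption for context, not code from this file) of why -100 works: PyTorch's CrossEntropyLoss skips targets equal to its default ignore_index of -100, so masked prompt positions contribute nothing to the fine-tuning loss.

# Toy example: labels of -100 are ignored by the loss, so only the answer tokens train.
import torch

logits = torch.randn(4, 10)                 # 4 token positions, toy vocab of 10
labels = torch.tensor([-100, -100, 3, 7])   # first two positions are "prompt", masked out
loss_fn = torch.nn.CrossEntropyLoss()       # ignore_index defaults to -100
print(loss_fn(logits, labels))              # loss averaged over the last two positions only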