 import copy
-import datasets
-from datasets import Dataset, load_dataset, DatasetDict
+from datasets import load_dataset
 import itertools

 B_INST, E_INST = "[INST]", "[/INST]"
@@ -26,8 +25,6 @@ def tokenize_dialog(dialog, tokenizer):
     eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
     labels = copy.copy(dialog_tokens)
     last_idx = 0
-    token_length = len(dialog_tokens)
-    last_idx = 0
     # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
     # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
     prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
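The loop that consumes these header sequences falls between the two hunks shown here, so for context, below is a minimal sketch of the masking technique, assuming the Llama 3 special-token IDs above (`128009` is `<|eot_id|>`). The names `mask_non_assistant_turns` and `contains_header` are illustrative, not the file's own, and the exact control flow is an approximation:

```python
import copy

def mask_non_assistant_turns(dialog_tokens):
    """Sketch: return labels where system/user turns are replaced by -100."""
    eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]  # <|eot_id|>
    labels = copy.copy(dialog_tokens)
    # Header sequences for system and user turns, as in the diff above.
    prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]

    def contains_header(seq, header):
        # True if `header` occurs as a contiguous subsequence of `seq`.
        return any(seq[i:i + 3] == header for i in range(len(seq) - 2))

    last_idx = 0
    for idx in eot_indices:
        # Examine the turn that ends at this <|eot_id|>.
        current_seq = labels[last_idx:idx + 1]
        if any(contains_header(current_seq, h) for h in prompt_header_seqs):
            # This turn belongs to the system or user; ignore it in the loss.
            labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
        last_idx = idx + 1
    return labels
```

Masking an entire turn to -100 tells the cross-entropy loss to skip system and user tokens, so gradients come only from the assistant's responses.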
@@ -44,18 +41,7 @@ def tokenize_dialog(dialog, tokenizer):
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
     else:
-        # Otherwise, use the original tokenizer to generate the tokens as it is from Llama 2 family models
-        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[:2]]
-        answer = dialog[-1]
-        answer_tokens = tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False)
-
-        # Add labels, convert prompt token to -100 in order to ignore in loss function
-        sample = {
-            "input_ids": prompt_tokens + answer_tokens,
-            "attention_mask": [1] * (len(prompt_tokens) + len(answer_tokens)),
-            "labels": [-100] * len(prompt_tokens) + answer_tokens,
-        }
-        return sample
+        raise Exception("This raft_dataset only supports Llama 3 family models, please make sure the tokenizer is from Llama 3 family models.")

     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
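As a quick illustration of the one-level flattening that `combined_tokens["input_ids"]` performs with `itertools.chain`, here is a standalone example; the token values are hypothetical:

```python
import itertools

# Hypothetical nested token lists, mirroring the dialog_tokens = [dialog_tokens]
# wrapping in the Llama 3 branch above.
dialog_tokens = [[128000, 128006, 9125, 128007], [271, 128009]]

# Flatten one level, as combined_tokens["input_ids"] does.
flat = list(itertools.chain(*(t for t in dialog_tokens)))
print(flat)  # [128000, 128006, 9125, 128007, 271, 128009]
```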