import copy
from datasets import load_dataset
import itertools
import torch
# Check whether any of the target token sequences (system or user prompt headers)
# appears in the given token list
def check_header(targets, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] in targets:
            return True
    return False

# Replace every occurrence of the target token sequence with the ignore index -100
def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i+3] == target:
            seq[i], seq[i+1], seq[i+2] = -100, -100, -100
    return seq
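
# Toy illustration of the two helpers above (the header and eot IDs are the real
# Llama 3 token IDs; the other values are made up):
#   seq = [128006, 882, 128007, 11, 128009, 128006, 78191, 128007, 42]
#   check_header([[128006, 882, 128007]], seq)    -> True (user header present)
#   replace_target([128006, 78191, 128007], seq)
#   -> [128006, 882, 128007, 11, 128009, -100, -100, -100, 42]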

def tokenize_dialogs(dialogs, images, processor):
    # A vocab size above 128000 means a Llama 3 family model, so the processor's
    # chat template can be used to render the dialogs as-is
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
    batch["labels"] = batch["input_ids"].clone()  # start from a copy of input_ids, then mask
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
        eot_indices = [k for k, n in enumerate(labels) if n == 128009]
        last_idx = 0
        # system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007]
        # user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007]
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx:idx+1]
            if check_header(prompt_header_seqs, current_seq):
                # found a prompt header, so this whole segment should be masked
                labels[last_idx:idx+1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx + 1
        # lastly, mask every assistant header "<|start_header_id|>assistant<|end_header_id|>",
        # which is tokenized to [128006, 78191, 128007]
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        batch["labels"][i] = torch.tensor(labels)
    return batch
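
# Positions set to -100 are skipped by PyTorch's cross-entropy loss (its default
# ignore_index), so only the assistant's answer tokens contribute to the loss;
# system/user turns and all header tokens are ignored during training.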


def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict with all the data, keyed by split
    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
    dataset = dataset_dict[split]
    # keep only the first 100 examples for a quick fine-tuning run
    dataset = dataset.select(range(100))
    return dataset
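
# Each sample is expected to look roughly like the following (inferred from how
# the collator below reads it; field contents are illustrative):
#   {
#       "images": [<PIL.Image>],
#       "messages": [
#           {"role": "user", "content": [{"type": "image"},
#                                        {"type": "text", "text": "How far apart are the chairs?"}]},
#           {"role": "assistant", "content": [{"type": "text", "text": "About one meter."}]},
#       ],
#   }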

class VQADataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.processor.tokenizer.padding_side = "right"  # during training, one always uses padding on the right

    def __call__(self, samples):
        dialogs, images = [], []
        for sample in samples:
            image, sample_text = sample["images"], sample["messages"]
            dialog = []
            for line in sample_text:
                content = []
                messages = line["content"]
                role = line["role"]
                for message in messages:
                    if message["type"] == "image":
                        content.append({"type": "image"})
                    elif message["type"] == "text":
                        content.append({"type": "text", "text": message["text"].strip()})
                dialog.append({"role": role, "content": content})
            dialogs.append(dialog)
            images.append(image)
        return tokenize_dialogs(dialogs, images, self.processor)

def get_data_collator(processor):
    return VQADataCollator(processor)
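
# A minimal end-to-end sketch of how these pieces are typically wired together.
# The model id below is an assumption; any Llama 3.2 vision processor that uses
# the same chat template should work.
if __name__ == "__main__":
    from transformers import AutoProcessor
    from torch.utils.data import DataLoader

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")  # assumed model id
    train_set = get_custom_dataset(None, processor, "train")
    loader = DataLoader(train_set, batch_size=2, collate_fn=get_data_collator(processor))
    batch = next(iter(loader))
    # input_ids and labels share a shape; prompt positions in labels are -100
    print(batch["input_ids"].shape, batch["labels"].shape)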