@@ -23,7 +23,7 @@ def tokenize_dialogs(dialogs, images, processor):
    text_prompt = processor.apply_chat_template(dialogs)
    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
-    batch["labels"] = copy.copy(batch["input_ids"])
+    label_list = []
    for i in range(len(batch["input_ids"])):
        dialog_tokens = batch["input_ids"][i].tolist()
        labels = copy.copy(dialog_tokens)
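# For reference, apply_chat_template here consumes dialogs in the Hugging Face
# chat format that VQADataCollator (below) assembles; an illustrative
# one-sample batch (the question/answer text is made up):
dialogs = [[
    {"role": "user", "content": [{"type": "image"},
                                 {"type": "text", "text": "How far apart are the chairs?"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "About one meter."}]},
]]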
@@ -42,14 +42,62 @@ def tokenize_dialogs(dialogs, images, processor):
        # Lastly, mask the assistant header <|start_header_id|>assistant<|end_header_id|>, which tokenizes to [128006, 78191, 128007]
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
-        batch["labels"][i] = torch.tensor(labels)
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
    return batch
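# check_header and replace_target are helpers defined elsewhere in this file
# (outside the hunks shown). A minimal sketch consistent with how they are
# called here, assuming both scan for 3-token header sequences:
def check_header_sketch(targets, seq):
    # True if any 3-token target header occurs anywhere in seq
    for i in range(len(seq) - 2):
        if seq[i:i + 3] in targets:
            return True
    return False

def replace_target_sketch(target, seq):
    # mask every occurrence of the 3-token target header with -100, in place
    for i in range(len(seq) - 2):
        if seq[i:i + 3] == target:
            seq[i:i + 3] = [-100, -100, -100]
    return seq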
+
+def tokenize_dialog(dialog, images, processor):
+    # For Llama 3 family models (vocab size above 128000), use the chat template to generate the tokens
+    text_prompt = processor.apply_chat_template(dialog)
+    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
+    labels = copy.copy(batch["input_ids"].tolist()[0])
+    eot_indices = [i for i, n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" tokenizes to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" tokenizes to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
+    for idx in eot_indices:
+        current_seq = labels[last_idx:idx + 1]
+        if check_header(prompt_header_seqs, current_seq):
+            # found a prompt header, so this segment should be masked
+            labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
+        else:
+            last_idx = idx + 1
+    # Lastly, mask the assistant header <|start_header_id|>assistant<|end_header_id|>, which tokenizes to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq, labels)
+    batch["labels"] = torch.tensor(labels)
+    return batch
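# Worked example of the masking above, using the real Llama 3 special IDs
# (the content tokens 111 and 222 are made up):
#   labels before: [128006, 882, 128007, 111, 128009,     <- user turn
#                   128006, 78191, 128007, 222, 128009]   <- assistant turn
#   labels after:  [-100, -100, -100, -100, -100,
#                   -100, -100, -100, 222, 128009]
# The whole user turn (including its <|eot_id|> 128009) and the assistant
# header are masked, so the loss covers only the assistant's reply tokens.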

def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # load_dataset will return a DatasetDict that contains all the data in the train set
    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
    dataset = dataset_dict[split]
-    dataset = dataset.select(range(100))
+    dataset = dataset.select(range(500))
    return dataset
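# Note that split_ratio is accepted but not applied above. If the hub dataset
# only ships a "train" split, one hypothetical way to honor it would be:
#     splits = dataset_dict["train"].train_test_split(test_size=1 - split_ratio)
#     dataset = splits["train" if split == "train" else "test"]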

class VQADataCollator:
@@ -74,5 +122,20 @@ def __call__(self, samples):
            dialogs.append(dialog)
            images.append(image)
        return tokenize_dialogs(dialogs, images, self.processor)
+    def __callworking__(self, samples):
+        for sample in samples:
+            image, sample_text = sample["images"], sample["messages"]
+            dialog = []
+            for line in sample_text:
+                content = []
+                messages = line["content"]
+                role = line["role"]
+                for message in messages:
+                    if message["type"] == "image":
+                        content.append({"type": "image"})
+                    elif message["type"] == "text":
+                        content.append({"type": "text", "text": message["text"].strip()})
+                dialog.append({"role": role, "content": content})
+            # Note: this returns inside the loop, so only the first sample is tokenized
+            return tokenize_dialog(dialog, image, self.processor)

def get_data_collator(processor):
    return VQADataCollator(processor)
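# Usage sketch, assuming the Llama 3.2 Vision processor from transformers
# (the model id is illustrative, not pinned by this diff):
from transformers import AutoProcessor
from torch.utils.data import DataLoader

processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
train_ds = get_custom_dataset(None, processor, "train")
loader = DataLoader(train_ds, batch_size=2, collate_fn=get_data_collator(processor))
batch = next(iter(loader))  # dict with input_ids, labels, pixel_values, ...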