 import evaluate
 from sklearn import preprocessing
 import numpy as np
+import sys

 from datasets import load_dataset_builder
 from datasets import load_dataset

 class VQATextDataset(Dataset):
-    def __init__(self, df, split, transforms, labelencoder, tokenizer=None):
+    def __init__(self, df, split, transforms, answer_set, tokenizer=None):
         self.df = df
         self.transforms = transforms
         self.tokenize = tokenizer
-        self.labels = labelencoder.transform(df['multiple_choice_answer'])
+        self.num_classes = len(answer_set)
     def __len__(self):
         return len(self.df)

@@ -34,13 +35,20 @@ def __getitem__(self, idx):
         image = Image.open(str(img_path))
         text = item["question"]
-        label = self.labels[idx]
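+        # Build a multi-hot soft target over the answer vocabulary: each annotator answer
+        # kept in 'answer_list' contributes its precomputed soft score from 'answer_weights'.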
+        target = np.zeros(self.num_classes, dtype=np.float32)
+        for i in range(len(item['answer_list'])):
+            target[item['answer_list'][i]] = item['answer_weights'][i]
+
         return {
             'image': self.transforms(image),
             'text': self.tokenize([text])[0],
-            'label': torch.tensor(label)
+            'target': torch.tensor(target)
         }

-def get_task_dataloaders(path, transforms, labelencoder, args):
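+# Soft answer score in the style of the VQAv2 accuracy rule: an answer given by at least
+# three of a question's annotators counts as fully correct, fewer gets partial credit.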
+def get_score(count: int) -> float:
+    return min(1.0, count / 3)
+
+def get_task_dataloaders(path, transforms, labelencoder, answer_set, args):
     tokenizer = get_tokenizer(args.model)
     dataloaders = {}

@@ -52,29 +60,43 @@ def get_task_dataloaders(path, transforms, labelencoder, args):
         questions = []
         images = []
         answers = []
+        weights = []
         for index, row in dataset_df.iterrows():
-            if (row['multiple_choice_answer'] in answer_set):
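+            # Tally the annotator answers for this question, then convert those that appear
+            # in the answer vocabulary into (label index, soft score) pairs.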
+            answer_count = {}
+            for answer in row['answers']:
+                answer_ = answer["answer"]
+                answer_count[answer_] = answer_count.get(answer_, 0) + 1
+            labels = []
+            scores = []
+            for answer in answer_count:
+                if answer not in answer_set:
+                    continue
+                labels.append(labelencoder.transform([answer])[0])
+                score = get_score(answer_count[answer])
+                scores.append(score)
+            if (len(labels) == 0):
+                continue
             class_id.append(row['question_id'])
             questions.append(row['question'])
             images.append(row['image'])
-            answers.append(row['multiple_choice_answer'])
+            answers.append(labels)
+            weights.append(scores)
+
         class_id = np.array(class_id)
         questions = np.array(questions)
         images = np.array(images)
-        answers = np.array(answers)
-
-        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'multiple_choice_answer': answers})
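+        # 'answer_list' holds the label indices and 'answer_weights' the matching soft scores for each question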
+        dataset_df = pd.DataFrame({'question_id': class_id, 'question': questions, 'image': images, 'answer_list': answers, 'answer_weights': weights})
         #dataset_df = dataset_df[0:12800]
         b_size = args.batch_size
         if (split == "validation"):
             b_size = args.batch_size * 20
             dataset_df = dataset_df[0:12800]
-        dataset = VQATextDataset(dataset_df,
-                                 split,
-                                 transforms,
-                                 labelencoder,
-                                 tokenizer=tokenizer,
-                                 )
+        dataset = VQATextDataset(dataset_df,
+                                 split,
+                                 transforms,
+                                 answer_set,
+                                 tokenizer=tokenizer,
+                                 )
         dataloader = DataLoader(
             dataset,
             batch_size=b_size,
@@ -95,7 +117,7 @@ def __init__(self, encoder, embed_dim, num_labels):

         self.fc1 = nn.Linear(embed_dim * 2, 1536) #size of answer space
         self.lnorm = nn.LayerNorm(1536)
-        self.fc2 = nn.Linear(1536, num_classes)
+        self.fc2 = nn.Linear(1536, num_labels)
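+        # one logit per candidate answer in the label set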
     def forward(self, image, text):
         # CLIP doesn't have a multimodal encoder, so we concatenate the features
         text_features = self.encoder.encode_text(text)
@@ -136,16 +158,15 @@ def compute_metrics(model, dataloader, device, args):
     metric = evaluate.load("accuracy")
     val_loss = 0
     samples_seen = 0
-    loss_fn = nn.CrossEntropyLoss()
     for batch in dataloader:
         with torch.no_grad():
             image = batch["image"].to(device)
             text = batch["text"].to(device)
-            label = batch["label"].to(device)
+            label = batch["target"].to(device)
             samples_seen += text.shape[0]
             logits = model(image, text)
             predictions = torch.argmax(logits, dim=-1)
-            batch_val_loss = loss_fn(logits, label)
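+            # Soft targets make this a multi-label problem, so each answer logit is scored
+            # with an independent sigmoid via BCE-with-logits rather than a single softmax.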
+            batch_val_loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum")
             val_loss += batch_val_loss.item()
             print(val_loss)
             metric.add_batch(
@@ -164,31 +185,29 @@ def train_single_epoch(model, data, optimizer, args):
164185 for i , batch in enumerate (data ["train" ]):
165186 image = batch ["image" ].to (device )
166187 text = batch ["text" ].to (device )
167- label = batch ["label " ].to (device )
188+ label = batch ["target " ].to (device )
168189
169190 logits = model (image , text )
170191 print (label .shape )
171192 print (logits .shape )
172- loss_fn = nn .CrossEntropyLoss ()
173- loss = loss_fn (logits , label )
193+ loss = nn .functional .binary_cross_entropy_with_logits (logits , label , reduction = "sum" )
174194 print (loss )
175195 loss .backward ()
176196
177197
 def train_one_epoch(model, data, epoch, optimizer, scheduler, early_stop, device, args):
     model.train()
-    loss_fn = nn.CrossEntropyLoss()
     progress_bar = tqdm(total=len(data["train"]))
     for i, batch in enumerate(data["train"]):
         step = epoch * len(data["train"]) + i
         scheduler(step)

         image = batch["image"].to(device)
         text = batch["text"].to(device)
-        label = batch["label"].to(device)
+        label = batch["target"].to(device)
         logits = model(image, text)

-        loss = loss_fn(logits, label) #should be cross entropy
+        loss = nn.functional.binary_cross_entropy_with_logits(logits, label, reduction="sum") # soft-target BCE over the answer vocabulary

         optimizer.zero_grad()
         loss.backward()
@@ -228,7 +247,7 @@ def parse_args(args):
     parser.add_argument(
         "--epochs", type=int, default=10, help="Number of epochs to train for."
     )
-    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate.")
+    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
     parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.")
     parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.")
     parser.add_argument("--eps", type=float, default=1e-8, help="Adam epsilon.")
@@ -273,8 +292,8 @@ def parse_args(args):
     args = parser.parse_args(args)
     return args

-if __name__ == "__main__":
-    args = parse_args([])
+def main(args):
+    args = parse_args(args)
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

     model, preprocess_train, preprocess_val = open_clip.factory.create_model_and_transforms(
@@ -287,7 +306,7 @@ def parse_args(args):
     embed_dim = model_cfg["embed_dim"]

     answer_space = []
-    with open('answers_vqa.txt') as f:
+    with open('src/training/answers_vqa.txt') as f:
         for line in f:
             answer_space.append(line.strip())
     answer_space = np.array(answer_space)
@@ -298,7 +317,7 @@ def parse_args(args):

     answer_set = set(labelencoder.classes_)

-    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, args)
+    data = get_task_dataloaders("HuggingFaceM4/VQAv2", preprocess_val, labelencoder, answer_set, args)

     clf_cls = CLIPMultimodalClassifier
     clf = clf_cls(model, embed_dim, num_classes).to(device)
@@ -314,3 +333,6 @@ def parse_args(args):

     for epoch in range(20):
         val_metrics, end_training = train_one_epoch(clf, data, epoch, optim, scheduler, early_stop, device, args)
+
+if __name__ == "__main__":
+    main(sys.argv[1:])