@@ -440,22 +440,24 @@ def __getitem__(self, i):
440440 frames = self .transform (frames )
441441
442442 verb , noun = label .split (':' )
443-
444443 verb , noun = self .verbs [int (verb )], self .nouns [int (noun )]
445444
446445 letters = [chr (65 + i ) for i in range (26 )][:self .topk_predictions ]
447446 options = list (range (26 ))[:self .topk_predictions ]
448-
447+ option_names = []
449448 # wrong answer can come from any valid gt
450449 wrong_answer_indices = np .random .choice (len (self .valid_gts ), size = 5 , replace = False )
451450 wrong_answers = [self .valid_gts [index ] for index in wrong_answer_indices ]
452451 for i in range (len (wrong_answers )):
453452 options [i ] = f'{ letters [i ]} . { wrong_answers [i ]} '
453+ option_names .append (wrong_answers [i ])
454454
455455 # correct answer must come from the available letters
456456 correct_answer_index = np .random .choice (len (letters ), size = 1 , replace = False )[0 ]
457457 correct_answer_letter = letters [correct_answer_index ]
458458
459+ option_names [correct_answer_index ] = f'{ verb } { noun } '
460+
459461 options [correct_answer_index ] = f'{ correct_answer_letter } . { verb } { noun } '
460462
461463 data = {
@@ -467,7 +469,7 @@ def __getitem__(self, i):
467469 'answer_name' : {0 : f'{ verb } { noun } ' }
468470 }
469471
470- return frames , data
472+ return frames , data , option_names
471473
472474
473475def get_downstream_dataset (transform , crop_size , args , subset = 'train' , label_mapping = None , labels = None ):
@@ -644,18 +646,19 @@ def get_topk_predictions(data, idx, k):
644646
645647 letters = [chr (65 + i ) for i in range (26 )][:k ]
646648 options = list (range (26 ))[:k ]
647-
648- predictions = data [str (idx )]['predictions' ][:k ]
649649
650+ predictions = data [str (idx )]['predictions' ][:k ]
651+ predictions = [e .replace (':' , ' ' ) for e in predictions ]
652+
650653 for i in range (len (options )):
651654 options [i ] = f'{ letters [i ]} . { predictions [i ]} '
652-
653-
655+
654656 mc_data = {
655657 'question' : {0 : 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer' },
656658 'option' : {0 : options }
657- }
658- return mc_data
659+ }
660+
661+ return mc_data , predictions
659662
660663
661664if __name__ == '__main__' :
@@ -712,21 +715,29 @@ def get_topk_predictions(data, idx, k):
712715
713716 tokenizer , model , image_processor , max_length = prepare_llava ()
714717
715- with open (args .action_predictions , 'r' ) as f :
716- predictions = json .load (f )
718+ if args .action_predictions :
719+ with open (args .action_predictions , 'r' ) as f :
720+ predictions = json .load (f )
721+
722+ avaion_correct = 0
717723
718- for idx , (frames , mc_data ) in tqdm (enumerate (val_dataloader )):
724+ for idx , (frames , mc_data , option_names ) in tqdm (enumerate (val_dataloader )):
719725
720726 gt = mc_data ['answer' ][0 ][0 ]
727+
728+ gt_name = mc_data ['answer_name' ][0 ][0 ]
721729
722- gts .append (gt )
730+ gts .append (gt_name )
723731
724732 if args .action_predictions :
725- mc_data = get_topk_predictions (predictions , idx , args .topk_predictions )
726-
733+ mc_data , avaion_pred = get_topk_predictions (predictions , idx , args .topk_predictions )
734+ if gt_name == avaion_pred [0 ]:
735+ avaion_correct += 1
736+ else :
737+ pass
727738
728739 pred = llava_inference (frames , tokenizer , model , image_processor , max_length , mc_data , num_frames = args .llava_num_frames )
729-
740+
730741 # if valid letter is found in the prediction, then we will use that as the prediction
731742 found = False
732743
@@ -738,17 +749,38 @@ def get_topk_predictions(data, idx, k):
738749 if not found :
739750 pred = 'N/A'
740751
741-
742- preds .append (pred )
752+ if args .action_predictions :
753+ if pred in valid_letters :
754+ pred_index = valid_letters .index (pred )
755+ pred_name = avaion_pred [pred_index ]
756+ else :
757+ pred_name = 'N/A'
758+ else :
759+ if pred in valid_letters :
760+ pred_index = valid_letters .index (pred )
761+ pred_name = option_names [pred_index ][0 ]
762+ else :
763+ pred_name = 'N/A'
764+
765+ print ('gt_name' , gt_name )
766+ print ('pred_name' , pred_name )
767+
768+ preds .append (pred_name )
743769
744770 # Update running corrects and total samples
745- running_corrects += (pred == gt )
771+ running_corrects += (pred_name == gt_name )
746772 total_samples += 1
747773
748774 # Calculate and log running mean accuracy
749775 running_accuracy = running_corrects / total_samples
776+
750777 logger .info (f'Running accuracy after { total_samples } samples: { running_accuracy :.4f} ' )
751778
779+ if args .action_predictions :
780+ avaion_accuracy = avaion_correct / total_samples
781+ logger .info (f'Running avaion accuracy after { total_samples } samples: { avaion_accuracy :.4f} ' )
782+
783+
752784 gts = np .array (gts )
753785 preds = np .array (preds )
754786 # get final accuracy
0 commit comments