1515import os
1616sys .path [0 ] = os .path .dirname (sys .path [0 ])
1717from action .llava_ov_inference import llava_inference
18+ import json
1819import logging
1920
2021
@@ -607,6 +608,9 @@ def get_args_parser():
607608 # llm size is type of string and can only be '7b' or '5b' etc.
608609 parser .add_argument ('--llm_size' , default = '7b' , type = str , help = 'llm size' )
609610 parser .add_argument ('--llava_num_frames' , default = 16 , type = int , help = 'number of frames for llava' )
611+ ## avaion refinement
612+ parser .add_argument ('--action_predictions' , default = None , type = str , help = 'path to action predictions' )
613+ parser .add_argument ('--topk_predictions' , default = 5 , type = int )
610614
611615 return parser
612616
@@ -625,6 +629,27 @@ def prepare_llava():
625629 return tokenizer , model , image_processor , max_length
626630
627631
632+ def get_topk_predictions (prediction_file , idx , k ):
633+
634+ with open (prediction_file , 'r' ) as f :
635+ data = json .load (f )
636+
637+ letters = [chr (65 + i ) for i in range (26 )][:k ]
638+ options = list (range (26 ))[:k ]
639+
640+ predictions = data [str (idx )]['predictions' ][:k ]
641+
642+ for i in range (len (options )):
643+ options [i ] = f'{ letters [i ]} . { predictions [i ]} '
644+
645+
646+ mc_data = {
647+ 'question' : {0 : 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer' },
648+ 'option' : {0 : options }
649+ }
650+ return mc_data
651+
652+
628653if __name__ == '__main__' :
629654 from moviepy .editor import ImageSequenceClip
630655 import torchvision
@@ -651,10 +676,18 @@ def prepare_llava():
651676 running_corrects = 0
652677 total_samples = 0
653678
654- valid_letters = ['A' , 'B' , 'C' , 'D' , 'E' ]
679+ if args .action_predictions :
680+ valid_letters = [chr (65 + i ) for i in range (26 )][args .topk_predictions ]
681+ else :
682+ valid_letters = ['A' , 'B' , 'C' , 'D' , 'E' ]
655683
684+ if not args .action_predictions :
685+ log_filename = f'llava_ov_{ args .llava_num_frames } f_{ args .llm_size } .log'
686+ else :
687+ log_filename = f'llava_ov_{ args .llava_num_frames } f_{ args .llm_size } _action_{ args .topk_predictions } .log'
688+
656689 # Set up logging
657- logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' , filename = f'llava_ov_ { args . llava_num_frames } f_ { args . llm_size } .log' , filemode = 'w' )
690+ logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' , filename = log_filename , filemode = 'w' )
658691
659692 console_handler = logging .StreamHandler (sys .stdout )
660693 console_handler .setLevel (logging .INFO )
@@ -671,11 +704,20 @@ def prepare_llava():
671704
672705 tokenizer , model , image_processor , max_length = prepare_llava ()
673706
674- for idx , (frames , gt ) in tqdm (enumerate (val_dataloader )):
675- pred = llava_inference (frames , tokenizer , model , image_processor , max_length , gt , num_frames = args .llava_num_frames )
707+ for idx , (frames , mc_data ) in tqdm (enumerate (val_dataloader )):
708+
709+ gt = mc_data ['answer' ][0 ][0 ]
710+
711+ gts .append (gt )
712+
713+ if args .action_predictions :
714+ mc_data = get_topk_predictions (args .action_predictions , idx , args .topk_predictions )
715+
716+ pred = llava_inference (frames , tokenizer , model , image_processor , max_length , mc_data , num_frames = args .llava_num_frames )
676717
677718 # if valid letter is found in the prediction, then we will use that as the prediction
678719 found = False
720+
679721 for letter in valid_letters :
680722 if letter in pred :
681723 pred = letter
@@ -684,11 +726,10 @@ def prepare_llava():
684726 if not found :
685727 pred = 'N/A'
686728
687- gts .append (gt ['answer' ][0 ][0 ])
688729 preds .append (pred )
689730
690731 # Update running corrects and total samples
691- running_corrects += (pred == gt [ 'answer' ][ 0 ][ 0 ] )
732+ running_corrects += (pred == gt )
692733 total_samples += 1
693734
694735 # Calculate and log running mean accuracy
0 commit comments