@@ -446,12 +446,13 @@ def __getitem__(self, i):
446446 letters = [chr (65 + i ) for i in range (26 )][:self .topk_predictions ]
447447 options = list (range (26 ))[:self .topk_predictions ]
448448
449-
449+ # wrong answers can be sampled from any valid ground-truth action
450450 wrong_answer_indices = np .random .choice (len (self .valid_gts ), size = 5 , replace = False )
451451 wrong_answers = [self .valid_gts [index ] for index in wrong_answer_indices ]
452452 for i in range (len (wrong_answers )):
453453 options [i ] = f'{ letters [i ]} . { wrong_answers [i ]} '
454-
454+
455+ # correct answer must come from the available letters
455456 correct_answer_index = np .random .choice (len (letters ), size = 1 , replace = False )[0 ]
456457 correct_answer_letter = letters [correct_answer_index ]
457458
@@ -460,7 +461,9 @@ def __getitem__(self, i):
460461 data = {
461462 'question' : {0 : 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer' },
462463 'option' : {0 : options },
464+ # the correct letter in the multiple-choice options
463465 'answer' : {0 : correct_answer_letter },
466+ # human-readable ground-truth action, kept for inspection
464467 'answer_name' : {0 : f'{ verb } { noun } ' }
465468 }
466469
@@ -637,10 +640,7 @@ def prepare_llava():
637640 return tokenizer , model , image_processor , max_length
638641
639642
640- def get_topk_predictions (prediction_file , idx , k ):
641-
642- with open (prediction_file , 'r' ) as f :
643- data = json .load (f )
643+ def get_topk_predictions (data , idx , k ):
644644
645645 letters = [chr (65 + i ) for i in range (26 )][:k ]
646646 options = list (range (26 ))[:k ]
@@ -711,6 +711,9 @@ def get_topk_predictions(prediction_file, idx, k):
711711 pretrained = f"lmms-lab/llava-onevision-qwen2-{ args .llm_size } -ov"
712712
713713 tokenizer , model , image_processor , max_length = prepare_llava ()
714+
715+ with open (args .action_predictions , 'r' ) as f :
716+ predictions = json .load (f )
714717
715718 for idx , (frames , mc_data ) in tqdm (enumerate (val_dataloader )):
716719
@@ -719,7 +722,7 @@ def get_topk_predictions(prediction_file, idx, k):
719722 gts .append (gt )
720723
721724 if args .action_predictions :
722- mc_data = get_topk_predictions (args . action_predictions , idx , args .topk_predictions )
725+ mc_data = get_topk_predictions (predictions , idx , args .topk_predictions )
723726
724727
725728 pred = llava_inference (frames , tokenizer , model , image_processor , max_length , mc_data , num_frames = args .llava_num_frames )
0 commit comments