@@ -395,7 +395,9 @@ def __init__(
395395 rcc_params = (224 ,),
396396 sparse_sample = False ,
397397 labels = None ,
398- is_trimmed = True ):
398+ is_trimmed = True ,
399+ topk_predictions = 5
400+ ):
399401 super ().__init__ (dataset , root , metadata , is_trimmed = is_trimmed )
400402
401403 self .transform = transform
@@ -416,6 +418,7 @@ def __init__(
416418 for noun in self .nouns :
417419 self .valid_gts .append (f'{ verb } { noun } ' )
418420 self .labels = labels
421+ self .topk_predictions = topk_predictions
419422
420423 def __getitem__ (self , i ):
421424 frames , label = self .get_raw_item (
@@ -440,21 +443,25 @@ def __getitem__(self, i):
440443
441444 verb , noun = self .verbs [int (verb )], self .nouns [int (noun )]
442445
443- letters = ['A' , 'B' , 'C' , 'D' , 'E' ]
444- options = [0 ,1 ,2 ,3 ,4 ]
446+ letters = [chr (65 + i ) for i in range (26 )][:self .topk_predictions ]
447+ options = list (range (26 ))[:self .topk_predictions ]
448+
449+
445450 wrong_answer_indices = np .random .choice (len (self .valid_gts ), size = 5 , replace = False )
446451 wrong_answers = [self .valid_gts [index ] for index in wrong_answer_indices ]
447452 for i in range (len (wrong_answers )):
448453 options [i ] = f'{ letters [i ]} . { wrong_answers [i ]} '
449- correct_answer_index = np .random .choice (5 , size = 1 , replace = False )[0 ]
454+
455+ correct_answer_index = np .random .choice (len (letters ), size = 1 , replace = False )[0 ]
450456 correct_answer_letter = letters [correct_answer_index ]
451457
452458 options [correct_answer_index ] = f'{ correct_answer_letter } . { verb } { noun } '
453459
454460 data = {
455461 'question' : {0 : 'the video is an egocentric view of a person. What is the person doing? Pick the the letter that has the correct answer' },
456462 'option' : {0 : options },
457- 'answer' : {0 : correct_answer_letter }
463+ 'answer' : {0 : correct_answer_letter },
464+ 'answer_name' : {0 : f'{ verb } { noun } ' }
458465 }
459466
460467 return frames , data
@@ -481,7 +488,8 @@ def get_downstream_dataset(transform, crop_size, args, subset='train', label_map
481488 threads = args .decode_threads ,
482489 fast_rcc = args .fused_decode_crop , rcc_params = (crop_size , ),
483490 is_trimmed = not args .dataset == 'charades_ego' ,
484- labels = labels
491+ labels = labels ,
492+ topk_predictions = args .topk_predictions
485493 )
486494 else :
487495 assert ValueError ("subset should be either 'train' or 'val'" )
@@ -677,7 +685,7 @@ def get_topk_predictions(prediction_file, idx, k):
677685 total_samples = 0
678686
679687 if args .action_predictions :
680- valid_letters = [chr (65 + i ) for i in range (26 )][args .topk_predictions ]
688+ valid_letters = [chr (65 + i ) for i in range (26 )][: args .topk_predictions ]
681689 else :
682690 valid_letters = ['A' , 'B' , 'C' , 'D' , 'E' ]
683691
@@ -712,7 +720,8 @@ def get_topk_predictions(prediction_file, idx, k):
712720
713721 if args .action_predictions :
714722 mc_data = get_topk_predictions (args .action_predictions , idx , args .topk_predictions )
715-
723+
724+
716725 pred = llava_inference (frames , tokenizer , model , image_processor , max_length , mc_data , num_frames = args .llava_num_frames )
717726
718727 # if valid letter is found in the prediction, then we will use that as the prediction
@@ -726,6 +735,7 @@ def get_topk_predictions(prediction_file, idx, k):
726735 if not found :
727736 pred = 'N/A'
728737
738+
729739 preds .append (pred )
730740
731741 # Update running corrects and total samples
0 commit comments