@@ -522,12 +522,34 @@ def train_generate(self, gt_vn, avion_predictions, narration, k, action_represen
522522 }
523523 return mc_data
524524
525- def test_generate (self , gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps ):
525+ def test_generate (self ,
526+ gt_vn ,
527+ action_model_predictions ,
528+ narration ,
529+ k ,
530+ action_representation ,
531+ n_narrations ,
532+ labels ,
533+ mapping_vn2narration ,
534+ verb_maps ,
535+ noun_maps ,
536+ benchmark_testing = False
537+ ):
526538 """
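527539 During testing, we use the top k predictions from the action model. By default the predictions are used as-is and gt_vn is not mixed in. When benchmark_testing is True, gt_vn replaces the last prediction if it is missing from the top k, and the choices are shuffled so that gt_vn is not always last.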
528540 """
529-
530- answer_ids = avion_predictions [:k ]
541+ answer_ids = action_model_predictions [:k ]
542+
543+ if benchmark_testing :
544+ # if we are testing on benchmark, we need to ensure that the gt_vn is in the top k predictions
545+ # if not, we remove the last prediction and add the gt_vn
546+ if gt_vn not in answer_ids :
547+ answer_ids .pop ()
548+ answer_ids .append (gt_vn )
549+
550+ # let's shuffle answer_ids so that the gt_vn is not always at the end
551+ random .shuffle (answer_ids )
552+
531553 answers = []
532554 for answer_id in answer_ids :
533555 answer = parse_vn_ids (answer_id , gt_vn , narration , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
@@ -566,7 +588,8 @@ def generate_multi_choice(self,
566588 mapping_vn2narration ,
567589 verb_maps ,
568590 noun_maps ,
569- is_train = True
591+ is_train = True ,
592+ benchmark_testing = False
570593 ):
571594 """
572595 Generate k multiple choices from gt_vn pairs
@@ -578,7 +601,7 @@ def generate_multi_choice(self,
578601 if is_train :
579602 return self .train_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
580603 else :
581- return self .test_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
604+ return self .test_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps , benchmark_testing = benchmark_testing )
582605
583606def get_frame_ids (start_frame , end_frame , num_segments = 32 , jitter = True ):
584607 frame_ids = np .convolve (np .linspace (start_frame , end_frame , num_segments + 1 ), [0.5 , 0.5 ], mode = 'valid' )
0 commit comments