@@ -411,7 +411,8 @@ def generate_multi_choice(self,
411411 mapping_vn2narration ,
412412 verb_maps ,
413413 noun_maps ,
414- is_train = True
414+ is_train = True ,
415+ benchmark_testing = False
415416 ):
416417
417418 """
@@ -425,7 +426,7 @@ def generate_multi_choice(self,
425426 if is_train :
426427 return self .train_generate (gt_vn , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
427428 else :
428- return self .test_generate (gt_vn , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
429+ return self .test_generate (gt_vn , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps , benchmark_testing = benchmark_testing )
429430
430431 def train_generate (self , gt_vn , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps ):
431432 # letters as A, B, C, D, .. Note we maximally support 26 letters
@@ -522,12 +523,34 @@ def train_generate(self, gt_vn, avion_predictions, narration, k, action_represen
522523 }
523524 return mc_data
524525
525- def test_generate (self , gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps ):
526+ def test_generate (self ,
527+ gt_vn ,
528+ action_model_predictions ,
529+ narration ,
530+ k ,
531+ action_representation ,
532+ n_narrations ,
533+ labels ,
534+ mapping_vn2narration ,
535+ verb_maps ,
536+ noun_maps ,
537+ benchmark_testing = False
538+ ):
526539 """
527540 During testing, we use the top k predictions from avion. No randomness. We do not mix the gt_vn with the avion predictions
528541 """
529-
530- answer_ids = avion_predictions [:k ]
542+ answer_ids = action_model_predictions [:k ]
543+
544+ if benchmark_testing :
545+ # if we are testing on benchmark, we need to ensure that the gt_vn is in the top k predictions
546+ # if not, we remove the last prediction and add the gt_vn
547+ if gt_vn not in answer_ids :
548+ answer_ids .pop ()
549+ answer_ids .append (gt_vn )
550+
551+ # let's shuffle answer_ids so that the gt_vn is not always at the end
552+ random .shuffle (answer_ids )
553+
531554 answers = []
532555 for answer_id in answer_ids :
533556 answer = parse_vn_ids (answer_id , gt_vn , narration , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
@@ -566,7 +589,8 @@ def generate_multi_choice(self,
566589 mapping_vn2narration ,
567590 verb_maps ,
568591 noun_maps ,
569- is_train = True
592+ is_train = True ,
593+ benchmark_testing = False
570594 ):
571595 """
572596 Generate k multiple choices from gt_vn pairs
@@ -578,7 +602,7 @@ def generate_multi_choice(self,
578602 if is_train :
579603 return self .train_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
580604 else :
581- return self .test_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps )
605+ return self .test_generate (gt_vn , avion_predictions , narration , k , action_representation , n_narrations , labels , mapping_vn2narration , verb_maps , noun_maps , benchmark_testing = benchmark_testing )
582606
583607def get_frame_ids (start_frame , end_frame , num_segments = 32 , jitter = True ):
584608 frame_ids = np .convolve (np .linspace (start_frame , end_frame , num_segments + 1 ), [0.5 , 0.5 ], mode = 'valid' )
0 commit comments