55import glob
66import json
77import os
8+ import re
9+
10+ def process_raw_pred (raw_pred ):
11+ matches = re .findall (r"[A-Z]\.\s(.+)" , raw_pred )
12+
13+ if 'None' in raw_pred :
14+ return raw_pred .replace ('None. ' , '' )
15+
16+ if matches :
17+ # Get the last match
18+ last_match = matches [- 1 ]
19+ # Remove a trailing period and anything after it
20+ last_match = re .sub (r"\.\s*.*$" , "" , last_match )
21+ return last_match
22+ else :
23+ return raw_pred
24+
825# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
926# annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
1027# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
2340benchmark_testing = True
2441
2542
26- def benchmark_avion_mcq (n_samples , gpt_model ):
43+ def benchmark_avion_mcq (n_samples , gpt_model , action_representation , benchmark_testing = True , n_frames = 8 ):
2744
2845 inferencer = GPTInferenceAnnotator (gpt_model ,
2946 root ,
@@ -39,7 +56,7 @@ def benchmark_avion_mcq(n_samples, gpt_model):
3956 inferencer .multi_process_run (n_samples = n_samples ,
4057 offset = 0 )
4158
42- def benchmark_tim_mcq (n_samples , gpt_model ):
59+ def benchmark_tim_mcq (n_samples , gpt_model , action_representation , benchmark_testing = True , n_frames = 8 ):
4360
4461 inferencer = GPTInferenceAnnotator (gpt_model ,
4562 root ,
@@ -54,7 +71,7 @@ def benchmark_tim_mcq(n_samples, gpt_model):
5471 topk = topk )
5572 inferencer .multi_process_run (n_samples = n_samples , offset = 0 )
5673
57- def benchmark_random_mcq (n_samples , gpt_model ):
74+ def benchmark_random_mcq (n_samples , gpt_model , action_representation , benchmark_testing = True , n_frames = 8 ):
5875 inferencer = GPTInferenceAnnotator (gpt_model ,
5976 root ,
6077 annotation_file ,
@@ -75,18 +92,34 @@ def calcuate_acc_from_jsons(json_folder):
7592 print (file )
7693 preds = json .load (open (file ))
7794 correct = 0
95+ something = 0
7896 for k ,v in preds .items ():
97+ options = v ['options' ]
98+ options = [process_raw_pred (e ) for e in options ]
99+
100+ #assert v['gt_name'] in options, f"{v['gt_name']} not in {options}"
101+ if v ['gt_name' ] not in options :
102+ print ('what?' , options )
103+ print ('what?' , v )
104+ break
105+
79106 if v ['gt_name' ] == v ['chatgpt_answer' ]:
80107 correct += 1
108+ else :
109+ pass
110+ #print ('wrong prediction! pred: gt', v['chatgpt_answer'] + "," + v['gt_name'])
81111 print ('acc ' , correct / len (preds ))
112+ print ('gt not in options' , something )
82113
83114
84115
85116if __name__ == '__main__' :
86- # benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18')
87- # benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18')
88- # benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18')
89- # benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06')
90- # benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06')
91- # benchmark_random_mcq(-1, 'gpt-4o-2024-08-06')
92- calcuate_acc_from_jsons ('gpt_full_benchmark_results' )
117+ # benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
118+ # benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
119+ # benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
120+ # benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
121+ # benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
122+ # benchmark_random_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
123+ benchmark_tim_mcq (1 , 'gpt-4o-mini-2024-07-18' , 'official_key' , benchmark_testing = False , n_frames = 16 )
124+ #benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = False, n_frames = 16)
125+ #calcuate_acc_from_jsons('gpt_EK100_results')
0 commit comments