1010import torch
1111import cv2
1212from pathlib import Path
13+ from tqdm import tqdm
1314from action .prediction_analysis import PredictionAnalysis
1415
1516client = openai .OpenAI (api_key = os .environ .get ("OPENAI_API_KEY" ))
@@ -186,12 +187,7 @@ def multi_process_run(self):
186187 def parse_item (self , item ):
187188
188189 gt_name = item ['gt_name' ]
189- avion_predictions = item ['avion_preds' ]['predictions' ]
190- # _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
191- # if gt_name not in _avion_predictions:
192- # print ('gt_name not in avion_predictions')
193- # else:
194- # print ('gt_name in avion_predictions')
190+ avion_predictions = item ['avion_preds' ]['predictions' ]
195191
196192 vid_path = item ['vid_path' ][0 ]
197193 start_second = item ['start_second' ]
@@ -215,7 +211,7 @@ def run(self, indices):
215211 data_batch = {i : self .data [i ] for i in range (len (self .data )) if i in indices }
216212 ret = {}
217213
218- for k ,v in data_batch .items ():
214+ for k ,v in tqdm ( data_batch .items () ):
219215 parsed_item = self .parse_item (v )
220216 start_timestamp = parsed_item ['start_second' ]
221217 end_timestamp = parsed_item ['end_second' ]
@@ -378,7 +374,7 @@ def run(self, indices):
378374 data_batch = [self .data [i ] for i in range (len (self .data )) if i in indices ]
379375
380376 ret = {}
381- for index in indices :
377+ for index in tqdm ( indices ) :
382378 item = self .data [index ]
383379 start_timestamp = item ['start_timestamp' ]
384380 end_timestamp = item ['end_timestamp' ]
@@ -444,36 +440,54 @@ def annotate(self, images, data_item):
444440 return response .choices [0 ].message .parsed
445441
446442
447- def multi_process_annotate (train_file_path , root ):
443+ def multi_process_annotate (train_file_path , root , debug = False ):
448444 annotator = GPTAugmentationAnnotator (train_file_path ,
449445 root ,
450446 clip_length = 4 ,
451- debug = True )
447+ debug = debug )
452448 results = annotator .multi_process_run ()
453449
454- def explore_wrong_examples (root , prediction_save_folder ):
450+ def explore_wrong_examples (root , prediction_save_folder , debug = False ):
455451 annotator = GPTInferenceAnnotator (root ,
456452 prediction_save_folder ,
457453 clip_length = 4 ,
458- debug = True )
454+ debug = debug )
459455 annotator .explore_wrong_examples ()
460456
461- def multi_process_inference (root , prediction_save_folder ):
457+ def multi_process_inference (root , prediction_save_folder , debug = False ):
462458
463459 annotator = GPTInferenceAnnotator (root ,
464460 prediction_save_folder ,
465- clip_length = 4 ,
466- debug = True )
461+ clip_length = 32 ,
462+ debug = debug )
463+
467464 annotator .multi_process_run ()
468465
466+ def calculate_gpt_accuracy (path ):
467+ with open (path , 'r' ) as f :
468+ data = json .load (f )
469+
470+ keys = list (data .keys ())
471+ print ('length of the data' , len (keys ))
472+
473+ correct_count = 0
474+ for k ,v in data .items ():
475+ gt_name = v ['gt_name' ]
476+ chatgpt_answer = v ['chatgpt_answer' ]
477+ if gt_name == chatgpt_answer :
478+ correct_count += 1
479+ else :
480+ print (chatgpt_answer , gt_name )
481+ print ('accuracy' , correct_count / len (keys ))
482+
469483if __name__ == '__main__' :
470484 #train_file_path = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
471485 #root = '/storage-rcp-pure/upmwmathis_scratch/shaokai/EK100'
472486 train_file_path = '/data/EK100_inst_train/avion_mc_top10/train_convs_narration.jsonl'
473487 root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
474488 pred_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
475489
476- multi_process_annotate (train_file_path , root )
490+ # multi_process_annotate(train_file_path, root)
477491 #explore_wrong_examples(root, pred_folder)
478-
479- #multi_process_inference(root, pred_folder )
492+ #multi_process_inference(root, pred_folder, debug = False)
493+ calculate_gpt_accuracy ( 'valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames.json' )
0 commit comments