1919import base64
2020from pathlib import Path
2121import traceback
22+ import cv2
2223
2324
2425client = openai .OpenAI (api_key = os .environ .get ("OPENAI_API_KEY" ))
@@ -348,7 +349,8 @@ def __init__(self,
348349 debug = False ,
349350 topk = 10 ,
350351 perspective = 'first_person' ,
351- benchmark_testing = False
352+ benchmark_testing = False ,
353+ do_visualization = False
352354 ):
353355 """
354356 Parameters
@@ -373,17 +375,31 @@ def __init__(self,
373375 self .perspective = perspective
374376 self .benchmark_testing = benchmark_testing
375377 assert gen_type in ['avion' , 'tim' , 'random' ]
376-
378+
377379 if gen_type == 'avion' or gen_type == 'tim' :
378380 self .mc_generator = ActionMultiChoiceGenerator (self .annotation_root )
381+ assert os .path .exists (self .prediction_file )
382+ print ('prediction_file' * 5 , self .prediction_file )
379383 with open (self .prediction_file , 'r' ) as f :
380384 self .action_model_predictions = json .load (f )
381385 else :
382386 self .mc_generator = RandomMultiChoiceGenerator (self .annotation_root )
383387
384-
388+ self .do_visualization = do_visualization
389+ self .vis_folder = f"{ self .gpt_model } _{ self .gen_type } _{ self .question_type } _{ self .perspective } "
385390 self .data = self .init_data ()
386-
391+
def save_visualization(self, frames, uid):
    """
    Write every frame of a clip to disk as a JPEG for visual inspection.

    Images are stored as ``<self.vis_folder>/<uid>/<uid>_<idx>.jpg``, where
    ``idx`` is the position of the frame in ``frames``. Missing folders are
    created on demand.

    Parameters
    ----------
    frames : iterable of image arrays
        Frames to save, in order. Presumably RGB arrays coming from the
        frame extractor -- TODO confirm against the caller.
    uid : str
        Identifier used both as the sub-folder name and as the file-name
        prefix of each saved frame.
    """
    sub_folder = Path(self.vis_folder) / uid
    # parents=True also creates self.vis_folder itself when it is missing,
    # so a separate mkdir on the parent is not needed.
    sub_folder.mkdir(parents=True, exist_ok=True)

    for idx, frame in enumerate(frames):
        target = sub_folder / f"{uid}_{idx}.jpg"
        # imwrite expects BGR channel order; this conversion swaps the first
        # and last channel, which maps RGB input to BGR (and vice versa).
        swapped = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # imwrite reports an unwritable target through its boolean return
        # value rather than an exception, so surface it instead of silently
        # dropping the frame. Saving stays best-effort: keep going.
        if not cv2.imwrite(str(target), swapped):
            print("Failed to write visualization frame: ", target)
402+
387403
388404 def init_data (self ):
389405 ret = {}
@@ -438,8 +454,8 @@ def init_data(self):
438454
439455 return ret
440456
441- def multi_process_run (self , n_samples = - 1 ):
442- # to initialize it
457+ def multi_process_run (self , n_samples = - 1 , disable_api_calling = False ):
458+ # inside GPT inference annotator
443459
444460 if n_samples != - 1 :
445461 indices = list (range (len (self .data )))[:n_samples ]
@@ -450,7 +466,7 @@ def multi_process_run(self, n_samples = -1):
450466
451467 with ProcessPoolExecutor (max_workers = num_chunks ) as executor :
452468 # Pass additional arguments to the function
453- futures = [executor .submit (self .run , group ) for group in indices_groups ]
469+ futures = [executor .submit (self .run , group , disable_api_calling ) for group in indices_groups ]
454470
455471 # Wait for all futures to complete
456472 combined_results = {}
@@ -460,16 +476,18 @@ def multi_process_run(self, n_samples = -1):
460476
461477 if self .debug :
462478 print (combined_results )
463-
464- calculation = calculate_gpt_accuracy (data = combined_results )
479+ if combined_results and 'mc_' in self . question_type :
480+ calculation = calculate_gpt_accuracy (data = combined_results )
465481
466482 prefix = self .gen_type
467483 assert n_samples != - 1
468484 checkpoint_name = f"{ prefix } _{ self .action_representation } _top{ self .topk } _{ self .clip_length } f_{ n_samples } samples.json"
469485
486+ if self .do_visualization :
487+ self .checkpoint (combined_results , os .path .join (self .vis_folder , checkpoint_name ))
470488 self .checkpoint (combined_results , checkpoint_name )
471489
472- def run (self , indices = None ):
490+ def run (self , indices = None , disable_api_calling = False ):
473491 if indices is None :
474492 data_batch = {i : self .data [i ] for i in range (len (self .data )) if i in list (range (len (self .data )))}
475493 else :
@@ -481,22 +499,36 @@ def run(self, indices=None):
481499 start_timestamp = v ['start_second' ]
482500 end_timestamp = v ['end_second' ]
483501 vid_path = v ['vid_path' ]
502+ _id = v ['vid_path' ].replace ('/' , '-' )
503+ uid = f"{ _id } _{ start_timestamp } _{ end_timestamp } "
484504
485505 frames , time_meta = self .extract_frames (vid_path , start_timestamp , end_timestamp )
486- try :
506+
507+ if self .do_visualization :
508+ # the output folder should reflect the gen type, question type and perspective
509+ # and the question type
510+ self .save_visualization (frames , uid )
511+ if disable_api_calling :
512+ break
513+ try :
487514 parsed_answer = self .predict_images (frames , v )
488515 except Exception as e :
489516 # get full stack trace
490- traceback .print_exc ()
491-
517+ traceback .print_exc ()
492518 print ("An exception occurred: " , e )
493519
494520 predicted_answer = parsed_answer .answer
495521 gt_name = v ['gt_answer' ]
496522 ret [k ] = {
523+ "uid" : uid ,
497524 'gt_name' : gt_name ,
498- 'chatgpt_answer' : process_raw_pred (predicted_answer ),
525+ "options" : v ['options' ],
526+ 'chatgpt_answer' : process_raw_pred (predicted_answer ) if 'mc_' in self .question_type else predicted_answer
499527 }
528+ if self .do_visualization :
529+ # save ret to the output folder
530+ self .checkpoint (ret , os .path .join (self .vis_folder , uid , 'inference_results.json' ))
531+
500532 if self .debug :
501533 break
502534
0 commit comments