import base64
from pathlib import Path
import traceback
+import cv2


client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -348,7 +349,8 @@ def __init__(self,
                 debug=False,
                 topk=10,
                 perspective='first_person',
-                 benchmark_testing=False
+                 benchmark_testing=False,
+                 do_visualization=False
                 ):
        """
        Parameters
@@ -373,17 +375,31 @@ def __init__(self,
        self.perspective = perspective
        self.benchmark_testing = benchmark_testing
        assert gen_type in ['avion', 'tim', 'random']
-
+
        if gen_type == 'avion' or gen_type == 'tim':
            self.mc_generator = ActionMultiChoiceGenerator(self.annotation_root)
+            assert os.path.exists(self.prediction_file)
            with open(self.prediction_file, 'r') as f:
                self.action_model_predictions = json.load(f)
        else:
            self.mc_generator = RandomMultiChoiceGenerator(self.annotation_root)

-
+        self.do_visualization = do_visualization
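+        # one visualization folder per (gpt model, gen type, question type, perspective) configuration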
+        self.vis_folder = f"{self.gpt_model}_{self.gen_type}_{self.question_type}_{self.perspective}"
+        os.makedirs(self.vis_folder, exist_ok=True)
        self.data = self.init_data()
-
+
+    def save_visualization(self, frames, uid):
+        """
+        Save the extracted frames as JPEGs under the visualization folder, one sub-folder per uid.
+        """
+        out_dir = Path(self.vis_folder)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        sub_folder = out_dir / uid
+        sub_folder.mkdir(parents=True, exist_ok=True)
+        for idx, frame in enumerate(frames):
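+            # frames are assumed to be RGB here; cv2.imwrite expects BGR, so swap channels before writing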
+            cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+

    def init_data(self):
        ret = {}
@@ -435,41 +451,45 @@ def init_data(self):
                'end_second': end_second,
                'vid_path': vid_path
            }
-
        return ret

-    def multi_process_run(self, n_samples=-1):
-        # to initialize it
+    def multi_process_run(self, offset=0, n_samples=-1, disable_api_calling=False):
+        # inside GPT inference annotator

-        if n_samples != -1:
-            indices = list(range(len(self.data)))[:n_samples]
+        if n_samples == -1:
+            # do not use offset if n_samples is -1
+            assert offset == 0

+        if n_samples != -1:
+            indices = list(range(len(self.data)))[offset:offset + n_samples]
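+            # i.e. a window of n_samples items starting at offset (useful for sharding or resuming long runs)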
        num_chunks = os.cpu_count() if not self.debug else 2

        indices_groups = self.split_indices(indices, num_chunks)

        with ProcessPoolExecutor(max_workers=num_chunks) as executor:
            # Pass additional arguments to the function
-            futures = [executor.submit(self.run, group) for group in indices_groups]
+            futures = [executor.submit(self.run, group, disable_api_calling) for group in indices_groups]

            # Wait for all futures to complete
            combined_results = {}
            for future in futures:
                result_dict = future.result()
                combined_results.update(result_dict)
-
+        print(combined_results)
        if self.debug:
            print(combined_results)
-
-        calculation = calculate_gpt_accuracy(data=combined_results)
+        if combined_results and 'mc_' in self.question_type:
+            calculation = calculate_gpt_accuracy(data=combined_results)

        prefix = self.gen_type
        assert n_samples != -1
        checkpoint_name = f"{prefix}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"

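+        # when visualizing, also keep a copy of the combined results next to the saved frames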
+        if self.do_visualization:
+            self.checkpoint(combined_results, os.path.join(self.vis_folder, checkpoint_name))
        self.checkpoint(combined_results, checkpoint_name)

-    def run(self, indices=None):
+    def run(self, indices=None, disable_api_calling=False):
        if indices is None:
            data_batch = {i: self.data[i] for i in range(len(self.data)) if i in list(range(len(self.data)))}
        else:
@@ -481,22 +501,36 @@ def run(self, indices=None):
            start_timestamp = v['start_second']
            end_timestamp = v['end_second']
            vid_path = v['vid_path']
+            _id = v['vid_path'].replace('/', '-')
+            uid = f"{_id}_{start_timestamp}_{end_timestamp}"
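+            # uid combines the flattened video path with the clip's start/end seconds; it names the per-clip visualization sub-folder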

            frames, time_meta = self.extract_frames(vid_path, start_timestamp, end_timestamp)
-            try:
+
+            if self.do_visualization:
+                # the output folder name reflects the gen type, question type and perspective
+                self.save_visualization(frames, uid)
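+            # skip the GPT call entirely; note this exits the batch loop after the first clip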
+            if disable_api_calling:
+                break
+            try:
                parsed_answer = self.predict_images(frames, v)
            except Exception as e:
                # get full stack trace
-                traceback.print_exc()
-
+                traceback.print_exc()
                print("An exception occurred: ", e)

            predicted_answer = parsed_answer.answer
            gt_name = v['gt_answer']
            ret[k] = {
+                "uid": uid,
                'gt_name': gt_name,
-                'chatgpt_answer': process_raw_pred(predicted_answer),
+                "options": v['options'],
+                'chatgpt_answer': process_raw_pred(predicted_answer) if 'mc_' in self.question_type else predicted_answer
            }
+            if self.do_visualization:
+                # save ret to the output folder
+                self.checkpoint(ret, os.path.join(self.vis_folder, uid, 'inference_results.json'))
+
            if self.debug:
                break

@@ -529,9 +563,7 @@ def predict_images(self, images, parsed_item):

        if 'o1' in self.gpt_model:
            system_prompt += format_prompt
-
-        #print(system_prompt)
-
+
        if self.handobj_root is not None:
            system_prompt += f"""To further assist you, we mark hands and objects when they are visible. The left hand is marked with a bounding box that contains the letter L and the right hand's bounding box contains the letter R. The object is marked as 'O'."""

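For orientation, a minimal usage sketch of the new flags. The class name GPTInferenceAnnotator is an assumption (the constructor definition lies outside the hunks shown), and constructor arguments not visible in this diff are elided:

annotator = GPTInferenceAnnotator(
    gen_type='avion',
    do_visualization=True,   # save frames and per-clip results under vis_folder
    # ... other constructor arguments (annotation_root, prediction_file, ...) omitted
)

# no GPT calls: each worker saves frames for its first clip, then stops
annotator.multi_process_run(offset=200, n_samples=100, disable_api_calling=True)

# run GPT inference on the first 500 clips and checkpoint the combined results
annotator.multi_process_run(n_samples=500)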