@@ -149,14 +149,21 @@ class GPTInferenceAnnotator(ChatGPT):
149149 Given the images, this class will annotate the video frames
150150 """
151151
152- def __init__ (self , root , prediction_save_folder , clip_length = 4 , debug = False ):
152+ def __init__ (self ,
153+ root ,
154+ prediction_save_folder ,
155+ clip_length = 4 ,
156+ debug = False ,
157+ topk = 10
158+ ):
153159 super ().__init__ (clip_length = clip_length )
154160 self .root = root
155161 self .prediction_save_folder = prediction_save_folder
156162 self .prediction_analysis = PredictionAnalysis (self .prediction_save_folder )
157163 self .prediction_analysis .load ()
158164 self .data = self .prediction_analysis .data
159165 self .debug = debug
166+ self .topk = topk
160167
161168 def multi_process_run (self ):
162169 prediction_analysis = PredictionAnalysis (self .prediction_save_folder )
@@ -187,6 +194,8 @@ def parse_item(self, item):
187194
188195 gt_name = item ['gt_name' ]
189196 avion_predictions = item ['avion_preds' ]['predictions' ]
197+ assert self .topk <= len (avion_predictions )
198+ avion_predictions = avion_predictions [:self .topk ]
190199 # _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
191200 # if gt_name not in _avion_predictions:
192201 # print ('gt_name not in avion_predictions')
@@ -458,12 +467,18 @@ def explore_wrong_examples(root, prediction_save_folder):
458467 debug = True )
459468 annotator .explore_wrong_examples ()
460469
461- def multi_process_inference (root , prediction_save_folder ):
470+ def multi_process_inference (root ,
471+ prediction_save_folder ,
472+ clip_length = 4 ,
473+ topk = 10 ,
474+ debug = False ):
462475
463476 annotator = GPTInferenceAnnotator (root ,
464477 prediction_save_folder ,
465- clip_length = 4 ,
466- debug = True )
478+ clip_length = clip_length ,
479+ debug = debug ,
480+ topk = topk )
481+
467482 annotator .multi_process_run ()
468483
469484if __name__ == '__main__' :
@@ -473,7 +488,8 @@ def multi_process_inference(root, prediction_save_folder):
473488 root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
474489 pred_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
475490
476- multi_process_annotate (train_file_path , root )
491+ # multi_process_annotate(train_file_path, root)
477492 #explore_wrong_examples(root, pred_folder)
478-
479- #multi_process_inference(root, pred_folder)
493+ multi_process_inference (root , pred_folder , debug = False ,
494+ clip_length = 4 ,
495+ topk = 5 )
0 commit comments