Skip to content

Commit 8219a28

Browse files
author
Ye Shaokai
committed
updates
1 parent 8f4883d commit 8219a28

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

action/chatgpt_utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,21 @@ class GPTInferenceAnnotator(ChatGPT):
149149
Given the images, this class will annotate the video frames
150150
"""
151151

152-
def __init__(self, root, prediction_save_folder, clip_length = 4, debug = False):
152+
def __init__(self,
153+
root,
154+
prediction_save_folder,
155+
clip_length = 4,
156+
debug = False,
157+
topk = 10
158+
):
153159
super().__init__(clip_length = clip_length)
154160
self.root = root
155161
self.prediction_save_folder = prediction_save_folder
156162
self.prediction_analysis = PredictionAnalysis(self.prediction_save_folder)
157163
self.prediction_analysis.load()
158164
self.data = self.prediction_analysis.data
159165
self.debug = debug
166+
self.topk = topk
160167

161168
def multi_process_run(self):
162169
prediction_analysis = PredictionAnalysis(self.prediction_save_folder)
@@ -187,6 +194,8 @@ def parse_item(self, item):
187194

188195
gt_name = item['gt_name']
189196
avion_predictions = item['avion_preds']['predictions']
197+
assert self.topk <= len(avion_predictions)
198+
avion_predictions = avion_predictions[:self.topk]
190199
# _avion_predictions = [e.replace(':', ' ', 1) for e in avion_predictions]
191200
# if gt_name not in _avion_predictions:
192201
# print ('gt_name not in avion_predictions')
@@ -458,12 +467,18 @@ def explore_wrong_examples(root, prediction_save_folder):
458467
debug = True)
459468
annotator.explore_wrong_examples()
460469

461-
def multi_process_inference(root, prediction_save_folder):
470+
def multi_process_inference(root,
471+
prediction_save_folder,
472+
clip_length = 4,
473+
topk = 10,
474+
debug = False):
462475

463476
annotator = GPTInferenceAnnotator(root,
464477
prediction_save_folder,
465-
clip_length = 4,
466-
debug = True)
478+
clip_length = clip_length,
479+
debug = debug,
480+
topk = topk)
481+
467482
annotator.multi_process_run()
468483

469484
if __name__ == '__main__':
@@ -473,7 +488,8 @@ def multi_process_inference(root, prediction_save_folder):
473488
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
474489
pred_folder = '/data/epic_kitchen/llavavideo_avion_mc_top10_5epoch_preds'
475490

476-
multi_process_annotate(train_file_path, root)
491+
#multi_process_annotate(train_file_path, root)
477492
#explore_wrong_examples(root, pred_folder)
478-
479-
#multi_process_inference(root, pred_folder)
493+
multi_process_inference(root, pred_folder, debug = False,
494+
clip_length = 4,
495+
topk = 5)

0 commit comments

Comments
 (0)