Skip to content

Commit d3bf3e6

Browse files
author
Ye Shaokai
committed
better benchmark code
1 parent 6fcd61c commit d3bf3e6

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

llava/action/benchmark.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,27 @@
33
# benchmark gpt-4o on random_mcq_top5_500
44
from llava.action.chatgpt_utils import GPTInferenceAnnotator
55

6-
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
7-
annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
8-
avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
9-
tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json'
10-
n_frames = 4
6+
# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
7+
# annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
8+
# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
9+
# tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json'
10+
11+
root = '/data/shaokai/EK100/'
12+
annotation_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
13+
avion_prediction_file = '/data/shaokai/AVION_PREDS/avion_pred_ids_val.json'
14+
tim_prediction_file = '/data/shaokai/TIM_PREDS/tim_pred_ids_val.json'
15+
16+
17+
n_frames = 16
1118
topk = 5
1219
action_representation = 'GT_random_narration'
13-
#gpt_model = 'gpt-4o-mini-2024-07-18'
14-
gpt_model = 'gpt-4o-2024-08-06'
20+
gpt_model = 'gpt-4o-mini-2024-07-18'
21+
#gpt_model = 'gpt-4o-2024-08-06'
1522
perspective = 'first_person'
1623
benchmark_testing = True
1724

1825

19-
def benchmark_avion_mcq(n_samples):
26+
def benchmark_avion_mcq(n_samples, gpt_model):
2027

2128
inferencer = GPTInferenceAnnotator(gpt_model,
2229
root,
@@ -29,9 +36,10 @@ def benchmark_avion_mcq(n_samples):
2936
perspective = perspective,
3037
benchmark_testing = benchmark_testing,
3138
topk = topk)
32-
inferencer.multi_process_run(n_samples)
39+
inferencer.multi_process_run(n_samples = n_samples,
40+
offset = 0)
3341

34-
def benchmark_tim_mcq(n_samples):
42+
def benchmark_tim_mcq(n_samples, gpt_model):
3543

3644
inferencer = GPTInferenceAnnotator(gpt_model,
3745
root,
@@ -44,9 +52,9 @@ def benchmark_tim_mcq(n_samples):
4452
perspective = perspective,
4553
benchmark_testing = benchmark_testing,
4654
topk = topk)
47-
inferencer.multi_process_run(n_samples)
55+
inferencer.multi_process_run(n_samples = n_samples, offset = 0)
4856

49-
def benchmark_random_mcq(n_samples):
57+
def benchmark_random_mcq(n_samples, gpt_model):
5058
inferencer = GPTInferenceAnnotator(gpt_model,
5159
root,
5260
annotation_file,
@@ -59,10 +67,10 @@ def benchmark_random_mcq(n_samples):
5967
benchmark_testing = benchmark_testing,
6068
topk = topk)
6169

62-
inferencer.multi_process_run(n_samples)
70+
inferencer.multi_process_run(n_samples = n_samples, offset = 0)
6371

6472

6573
if __name__ == '__main__':
66-
#benchmark_avion_mcq(100)
67-
benchmark_tim_mcq(100)
68-
#benchmark_random_mcq(100)
74+
benchmark_avion_mcq(-1)
75+
benchmark_tim_mcq(-1)
76+
benchmark_random_mcq(-1)

llava/action/chatgpt_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,9 +481,8 @@ def multi_process_run(self, offset= 0, n_samples = -1, disable_api_calling = Fal
481481
if combined_results and 'mc_' in self.question_type:
482482
calculation = calculate_gpt_accuracy(data = combined_results)
483483

484-
prefix = self.gen_type
485484
assert n_samples != -1
486-
checkpoint_name = f"{prefix}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
485+
checkpoint_name = f"{self.gpt_model}_{self.gen_type}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
487486

488487
if self.do_visualization:
489488
self.checkpoint(combined_results, os.path.join(self.vis_folder, checkpoint_name))

0 commit comments

Comments
 (0)