Skip to content

Commit 20a11bb

Browse files
author
Haozhe Qi
committed
Merge branch 'shaokai/dev' of github.com:yeshaokai/LLaVA-NeXT into shaokai/dev
2 parents ec1e6e0 + e9e3bd1 commit 20a11bb

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

llava/action/benchmark.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,28 @@
22
# benchmark gpt-4o on tim_mcq_top5_500
33
# benchmark gpt-4o on random_mcq_top5_500
44
from llava.action.chatgpt_utils import GPTInferenceAnnotator
5+
import glob
6+
import json
7+
import os
8+
# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
9+
# annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
10+
# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
11+
# tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json'
512

6-
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
7-
annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
8-
avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
9-
tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json'
10-
n_frames = 4
13+
root = '/data/shaokai/EK100/'
14+
annotation_file = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
15+
avion_prediction_file = '/data/shaokai/AVION_PREDS/avion_pred_ids_val.json'
16+
tim_prediction_file = '/data/shaokai/TIM_PREDS/tim_pred_ids_val.json'
17+
18+
19+
n_frames = 8
1120
topk = 5
1221
action_representation = 'GT_random_narration'
13-
#gpt_model = 'gpt-4o-mini-2024-07-18'
14-
gpt_model = 'gpt-4o-2024-08-06'
1522
perspective = 'first_person'
1623
benchmark_testing = True
1724

1825

19-
def benchmark_avion_mcq(n_samples):
26+
def benchmark_avion_mcq(n_samples, gpt_model):
2027

2128
inferencer = GPTInferenceAnnotator(gpt_model,
2229
root,
@@ -29,9 +36,10 @@ def benchmark_avion_mcq(n_samples):
2936
perspective = perspective,
3037
benchmark_testing = benchmark_testing,
3138
topk = topk)
32-
inferencer.multi_process_run(n_samples)
39+
inferencer.multi_process_run(n_samples = n_samples,
40+
offset = 0)
3341

34-
def benchmark_tim_mcq(n_samples):
42+
def benchmark_tim_mcq(n_samples, gpt_model):
3543

3644
inferencer = GPTInferenceAnnotator(gpt_model,
3745
root,
@@ -44,9 +52,9 @@ def benchmark_tim_mcq(n_samples):
4452
perspective = perspective,
4553
benchmark_testing = benchmark_testing,
4654
topk = topk)
47-
inferencer.multi_process_run(n_samples)
55+
inferencer.multi_process_run(n_samples = n_samples, offset = 0)
4856

49-
def benchmark_random_mcq(n_samples):
57+
def benchmark_random_mcq(n_samples, gpt_model):
5058
inferencer = GPTInferenceAnnotator(gpt_model,
5159
root,
5260
annotation_file,
@@ -59,10 +67,26 @@ def benchmark_random_mcq(n_samples):
5967
benchmark_testing = benchmark_testing,
6068
topk = topk)
6169

62-
inferencer.multi_process_run(n_samples)
70+
inferencer.multi_process_run(n_samples = n_samples, offset = 0)
71+
72+
def calcuate_acc_from_jsons(json_folder):
73+
files = glob.glob(os.path.join(json_folder, '*.json'))
74+
for file in files:
75+
print (file)
76+
preds = json.load(open(file))
77+
correct = 0
78+
for k,v in preds.items():
79+
if v['gt_name'] == v['chatgpt_answer']:
80+
correct+=1
81+
print ('acc ', correct/len(preds))
82+
6383

6484

6585
if __name__ == '__main__':
66-
#benchmark_avion_mcq(100)
67-
benchmark_tim_mcq(100)
68-
#benchmark_random_mcq(100)
86+
# benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18')
87+
# benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18')
88+
# benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18')
89+
# benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06')
90+
# benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06')
91+
# benchmark_random_mcq(-1, 'gpt-4o-2024-08-06')
92+
calcuate_acc_from_jsons('gpt_full_benchmark_results')

llava/action/chatgpt_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,8 @@ def multi_process_run(self, offset= 0, n_samples = -1, disable_api_calling = Fal
462462

463463
if n_samples != -1:
464464
indices = list(range(len(self.data)))[offset:offset + n_samples]
465+
else:
466+
indices = list(range(len(self.data)))
465467
num_chunks = os.cpu_count() if not self.debug else 2
466468

467469
indices_groups = self.split_indices(indices, num_chunks)
@@ -481,9 +483,7 @@ def multi_process_run(self, offset= 0, n_samples = -1, disable_api_calling = Fal
481483
if combined_results and 'mc_' in self.question_type:
482484
calculation = calculate_gpt_accuracy(data = combined_results)
483485

484-
prefix = self.gen_type
485-
assert n_samples != -1
486-
checkpoint_name = f"{prefix}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
486+
checkpoint_name = f"{self.gpt_model}_{self.gen_type}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
487487

488488
if self.do_visualization:
489489
self.checkpoint(combined_results, os.path.join(self.vis_folder, checkpoint_name))

0 commit comments

Comments
 (0)