Skip to content

Commit c614e7e

Browse files
author
Ye Shaokai
committed
updates
1 parent 9491494 commit c614e7e

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

llava/action/benchmark.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# benchmark gpt-4o on tim_mcq_top5_500
33
# benchmark gpt-4o on random_mcq_top5_500
44
from llava.action.chatgpt_utils import GPTInferenceAnnotator
5-
5+
import glob
6+
import json
7+
import os
68
# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
79
# annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
810
# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
@@ -14,7 +16,7 @@
1416
tim_prediction_file = '/data/shaokai/TIM_PREDS/tim_pred_ids_val.json'
1517

1618

17-
n_frames = 16
19+
n_frames = 8
1820
topk = 5
1921
action_representation = 'GT_random_narration'
2022
perspective = 'first_person'
@@ -67,12 +69,24 @@ def benchmark_random_mcq(n_samples, gpt_model):
6769

6870
inferencer.multi_process_run(n_samples = n_samples, offset = 0)
6971

72+
def calcuate_acc_from_jsons(json_folder):
73+
files = glob.glob(os.path.join(json_folder, '*.json'))
74+
for file in files:
75+
print (file)
76+
preds = json.load(open(file))
77+
correct = 0
78+
for k,v in preds.items():
79+
if v['gt_name'] == v['chatgpt_answer']:
80+
correct+=1
81+
print ('acc ', correct/len(preds))
82+
83+
7084

7185
if __name__ == '__main__':
72-
benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18')
73-
benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18')
74-
benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18')
75-
benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06')
76-
benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06')
77-
benchmark_random_mcq(-1, 'gpt-4o-2024-08-06')
78-
86+
# benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18')
87+
# benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18')
88+
# benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18')
89+
# benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06')
90+
# benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06')
91+
# benchmark_random_mcq(-1, 'gpt-4o-2024-08-06')
92+
calcuate_acc_from_jsons('gpt_full_benchmark_results')

llava/action/chatgpt_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,6 @@ def multi_process_run(self, offset= 0, n_samples = -1, disable_api_calling = Fal
483483
if combined_results and 'mc_' in self.question_type:
484484
calculation = calculate_gpt_accuracy(data = combined_results)
485485

486-
assert n_samples != -1
487486
checkpoint_name = f"{self.gpt_model}_{self.gen_type}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
488487

489488
if self.do_visualization:

0 commit comments

Comments
 (0)