
Commit 44ec08f

Fixed bugs in the benchmark code as well; the prediction file pointer was wrong.
1 parent 79829db commit 44ec08f

File tree

5 files changed: +124 -22 lines

llava/action/benchmark.py

Lines changed: 4 additions & 4 deletions

@@ -10,8 +10,8 @@
 n_frames = 4
 topk = 5
 action_representation = 'GT_random_narration'
-#gpt_model = 'gpt-4o-mini-2024-07-18'
-gpt_model = 'gpt-4o-2024-08-06'
+gpt_model = 'gpt-4o-mini-2024-07-18'
+#gpt_model = 'gpt-4o-2024-08-06'
 perspective = 'first_person'
 benchmark_testing = True

@@ -37,7 +37,7 @@ def benchmark_tim_mcq(n_samples):
         root,
         annotation_file,
         gen_type = 'tim',
-        prediction_file = avion_prediction_file,
+        prediction_file = tim_prediction_file,
         clip_length = n_frames,
         question_type = 'mc_',
         action_representation=action_representation,

@@ -63,6 +63,6 @@ def benchmark_random_mcq(n_samples):


 if __name__ == '__main__':
-    benchmark_avion_mcq(100)
+    #benchmark_avion_mcq(100)
     benchmark_tim_mcq(100)
     #benchmark_random_mcq(100)
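
The fix above points benchmark_tim_mcq at the TIM predictions instead of the AVION ones. As a minimal sketch of what the full helper presumably looks like, using the module-level constants from the top of benchmark.py: only the keyword arguments visible in the hunk are taken from the diff; the enclosing GPTInferenceAnnotator construction and the multi_process_run call are assumptions modeled on the visualization helpers added elsewhere in this commit.

def benchmark_tim_mcq(n_samples):
    # Hypothetical reconstruction: the keyword arguments mirror the hunk above,
    # the surrounding call pattern is assumed from this commit's new script.
    inferencer = GPTInferenceAnnotator(gpt_model,
        root,
        annotation_file,
        gen_type = 'tim',
        prediction_file = tim_prediction_file,   # the fix: TIM predictions, not avion_prediction_file
        clip_length = n_frames,
        question_type = 'mc_',
        action_representation = action_representation,
        perspective = perspective,
        benchmark_testing = benchmark_testing,
        topk = topk)
    inferencer.multi_process_run(n_samples)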

llava/action/chatgpt_utils.py

Lines changed: 46 additions & 14 deletions

@@ -19,6 +19,7 @@
 import base64
 from pathlib import Path
 import traceback
+import cv2


 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

@@ -348,7 +349,8 @@ def __init__(self,
                  debug = False,
                  topk = 10,
                  perspective = 'first_person',
-                 benchmark_testing = False
+                 benchmark_testing = False,
+                 do_visualization = False
                  ):
         """
         Parameters

@@ -373,17 +375,31 @@
         self.perspective = perspective
         self.benchmark_testing = benchmark_testing
         assert gen_type in ['avion', 'tim', 'random']
-
+
         if gen_type == 'avion' or gen_type == 'tim':
             self.mc_generator = ActionMultiChoiceGenerator(self.annotation_root)
+            assert os.path.exists(self.prediction_file)
+            print ('prediction_file'*5, self.prediction_file)
             with open(self.prediction_file, 'r') as f:
                 self.action_model_predictions = json.load(f)
         else:
             self.mc_generator = RandomMultiChoiceGenerator(self.annotation_root)

-
+        self.do_visualization = do_visualization
+        self.vis_folder = f"{self.gpt_model}_{self.gen_type}_{self.question_type}_{self.perspective}"
         self.data = self.init_data()
-
+
+    def save_visualization(self,frames, uid):
+        """
+        Save the frames to the out_dir
+        """
+        out_dir = Path(self.vis_folder)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        sub_folder = out_dir / uid
+        sub_folder.mkdir(parents=True, exist_ok=True)
+        for idx, frame in enumerate(frames):
+            cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+

     def init_data(self):
         ret = {}

@@ -438,8 +454,8 @@ def init_data(self):

         return ret

-    def multi_process_run(self, n_samples = -1):
-        # to initialize it
+    def multi_process_run(self, n_samples = -1, disable_api_calling = False):
+        # inside GPT inference annotator

         if n_samples != -1:
             indices = list(range(len(self.data)))[:n_samples]

@@ -450,7 +466,7 @@

         with ProcessPoolExecutor(max_workers=num_chunks) as executor:
             # Pass additional arguments to the function
-            futures = [executor.submit(self.run, group) for group in indices_groups]
+            futures = [executor.submit(self.run, group, disable_api_calling) for group in indices_groups]

             # Wait for all futures to complete
             combined_results = {}

@@ -460,16 +476,18 @@

         if self.debug:
             print (combined_results)
-
-        calculation = calculate_gpt_accuracy(data = combined_results)
+        if combined_results and 'mc_' in self.question_type:
+            calculation = calculate_gpt_accuracy(data = combined_results)

         prefix = self.gen_type
         assert n_samples != -1
         checkpoint_name = f"{prefix}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"

+        if self.do_visualization:
+            self.checkpoint(combined_results, os.path.join(self.vis_folder, checkpoint_name))
         self.checkpoint(combined_results, checkpoint_name)

-    def run(self, indices=None):
+    def run(self, indices=None, disable_api_calling = False):
         if indices is None:
             data_batch = {i : self.data[i] for i in range(len(self.data)) if i in list(range(len(self.data)))}
         else:

@@ -481,22 +499,36 @@
             start_timestamp = v['start_second']
             end_timestamp = v['end_second']
             vid_path = v['vid_path']
+            _id = v['vid_path'].replace('/', '-')
+            uid = f"{_id}_{start_timestamp}_{end_timestamp}"

             frames, time_meta = self.extract_frames(vid_path, start_timestamp, end_timestamp)
-            try:
+
+            if self.do_visualization:
+                # the output folder should reflect the gen type, question type and perspective
+                # and the question type
+                self.save_visualization(frames, uid)
+            if disable_api_calling:
+                break
+            try:
                 parsed_answer = self.predict_images(frames, v)
             except Exception as e:
                 # get full stack trace
-                traceback.print_exc()
-
+                traceback.print_exc()
                 print ("An exception occurred: ", e)

             predicted_answer = parsed_answer.answer
             gt_name = v['gt_answer']
             ret[k] = {
+                "uid": uid,
                 'gt_name': gt_name,
-                'chatgpt_answer': process_raw_pred(predicted_answer),
+                "options": v['options'],
+                'chatgpt_answer': process_raw_pred(predicted_answer) if 'mc_' in self.question_type else predicted_answer
             }
+            if self.do_visualization:
+                # save ret to the output folder
+                self.checkpoint(ret, os.path.join(self.vis_folder, uid, 'inference_results.json'))
+
             if self.debug:
                 break
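Taken together, the new do_visualization and disable_api_calling flags let the sampled frames be dumped for inspection without spending OpenAI credits: each worker saves the frames for the first segment of its chunk and breaks before predict_images is reached. A usage sketch, not part of the commit, mirroring visualize_with_gpt_with_tim from the new script but with API calling disabled:

from llava.action.chatgpt_utils import GPTInferenceAnnotator

inferencer = GPTInferenceAnnotator('gpt-4o-mini-2024-07-18',
    '/data/EK100/EK100_320p_15sec_30fps_libx264',
    '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv',
    gen_type = 'tim',
    prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json',
    clip_length = 4,
    question_type = 'mc_',
    action_representation = 'GT_random_narration',
    perspective = 'first_person',
    benchmark_testing = True,
    do_visualization = True,   # frames go to "{gpt_model}_{gen_type}_{question_type}_{perspective}/<uid>/"
    topk = 5)
# No call to predict_images is made; only the frame JPEGs are written.
inferencer.multi_process_run(1, disable_api_calling = True)
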
llava/action/ek_eval.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from llava.action.utils import generate_label_map, match_answer
 from collections import Counter
 import torch.distributed as dist
-from llava.action.dataset import VideoMultiChoiceDataset, VideoTemporalMultiChoiceDataset
+from llava.action.dataset import VideoMultiChoiceDataset
 import torchvision.io as io
 import re

Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+"""
+We need to keep track of the following:
+
+The uid of each segment
+
+The GPT inference of corresponding segment
+The LLaVA zero-shot inference of corresponding segment
+The Finetuned LLaVA's inference of corresponding segment
+
+Note that in each inference, we should be able to pick the corresponding prompt and checkpoint folder
+"""
+
+from llava.action.chatgpt_utils import GPTInferenceAnnotator
+
+root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
+annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
+tim_prediction_file = '/data/epic_kitchen/TIM_PREDS/tim_pred_ids_val.json'
+n_frames = 4
+topk = 5
+action_representation = 'GT_random_narration'
+gpt_model = 'gpt-4o-mini-2024-07-18'
+#gpt_model = 'gpt-4o-2024-08-06'
+perspective = 'first_person'
+benchmark_testing = True
+
+
+
+def visualize_with_random(n_samples, question_type = 'mc_'):
+    """
+    Here we should test gpt-4o, gpt-4o-mini with different prompts
+    """
+    inferencer = GPTInferenceAnnotator(gpt_model,
+        root,
+        annotation_file,
+        gen_type = 'random',
+        prediction_file = tim_prediction_file,
+        clip_length = n_frames,
+        question_type = question_type,
+        action_representation=action_representation,
+        perspective = perspective,
+        benchmark_testing = benchmark_testing,
+        do_visualization = True,
+        topk = topk)
+
+    inferencer.multi_process_run(n_samples, disable_api_calling=False)
+
+def visualize_with_gpt_with_tim(n_samples, question_type = 'mc_'):
+    """
+    Here we should test gpt-4o, gpt-4o-mini with different prompts
+    """
+    inferencer = GPTInferenceAnnotator(gpt_model,
+        root,
+        annotation_file,
+        gen_type = 'tim',
+        prediction_file = tim_prediction_file,
+        clip_length = n_frames,
+        question_type = question_type,
+        action_representation=action_representation,
+        perspective = perspective,
+        benchmark_testing = benchmark_testing,
+        do_visualization = True,
+        topk = topk)
+
+    inferencer.multi_process_run(n_samples, disable_api_calling=False)
+
+
+if __name__ == '__main__':
+    #visualize_with_random(1, question_type = "mc_")
+    visualize_with_gpt_with_tim(1, question_type = "mc_")
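
For reference, running visualize_with_gpt_with_tim(1) with the settings above should leave an output tree roughly like the sketch below, assembled from vis_folder, save_visualization and the checkpoint names in chatgpt_utils.py; the uid is the segment's vid_path with '/' replaced by '-', followed by its start and end seconds.

gpt-4o-mini-2024-07-18_tim_mc__first_person/
    <vid_path with '/' replaced by '-'>_<start_second>_<end_second>/
        <uid>_0.jpg ... <uid>_3.jpg                    # the n_frames sampled frames
        inference_results.json                         # uid, gt_name, options, chatgpt_answer
    tim_GT_random_narration_top5_4f_1samples.json      # combined checkpoint, copied here because do_visualization is True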

llava/action/utils.py

Lines changed: 3 additions & 3 deletions

@@ -428,7 +428,7 @@ def generate_multi_choice(self,
         else:
             return self.test_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)

-    def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
+    def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = False):
         # letters as A, B, C, D, .. Note we maximally support 26 letters
         letters = [chr(65+i) for i in range(26)][:k]
         answer_list = [vn for vn in mapping_vn2narration.keys()]

@@ -463,11 +463,11 @@ def train_generate(self, gt_vn, narration, k, action_representation, n_narration
         }
         return mc_data

-    def test_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
+    def test_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = False):
         """
         There is no difference between train and test for random generation
         """
-        return self.train_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
+        return self.train_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)

 class AvionMultiChoiceGenerator(MultiChoiceGenerator):
     """
