
Commit b84d560

Ye Shaokai authored and committed
Added vis utils
1 parent e97bccb commit b84d560

File tree

6 files changed: +394, -19 lines changed


llava/action/benchmark.py

Lines changed: 43 additions & 10 deletions
@@ -5,6 +5,23 @@
 import glob
 import json
 import os
+import re
+
+def process_raw_pred(raw_pred):
+    matches = re.findall(r"[A-Z]\.\s(.+)", raw_pred)
+
+    if 'None' in raw_pred:
+        return raw_pred.replace('None. ', '')
+
+    if matches:
+        # Get the last match
+        last_match = matches[-1]
+        # Remove a trailing period and anything after it
+        last_match = re.sub(r"\.\s*.*$", "", last_match)
+        return last_match
+    else:
+        return raw_pred
+
 # root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
 # annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
 # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
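
A quick illustration of what process_raw_pred does to a raw multiple-choice answer (the input strings below are made up for illustration, not taken from the benchmark outputs):

    process_raw_pred("C. take plate. It is on the counter")   # -> 'take plate'
    process_raw_pred("A. open the fridge")                     # -> 'open the fridge'
    process_raw_pred("no lettered option in the reply")        # -> returned unchanged

It grabs the text that follows an option letter such as "C. ", truncates it at the first trailing period, and falls back to the unmodified input when no such pattern is present.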
@@ -23,7 +40,7 @@
 benchmark_testing = True
 
 
-def benchmark_avion_mcq(n_samples, gpt_model):
+def benchmark_avion_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8):
 
     inferencer = GPTInferenceAnnotator(gpt_model,
                                        root,
@@ -39,7 +56,7 @@ def benchmark_avion_mcq(n_samples, gpt_model):
     inferencer.multi_process_run(n_samples = n_samples,
                                  offset = 0)
 
-def benchmark_tim_mcq(n_samples, gpt_model):
+def benchmark_tim_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8):
 
     inferencer = GPTInferenceAnnotator(gpt_model,
                                        root,
@@ -54,7 +71,7 @@ def benchmark_tim_mcq(n_samples, gpt_model):
                                        topk = topk)
     inferencer.multi_process_run(n_samples = n_samples, offset = 0)
 
-def benchmark_random_mcq(n_samples, gpt_model):
+def benchmark_random_mcq(n_samples, gpt_model, action_representation, benchmark_testing = True, n_frames = 8):
     inferencer = GPTInferenceAnnotator(gpt_model,
                                        root,
                                        annotation_file,
@@ -75,18 +92,34 @@ def calcuate_acc_from_jsons(json_folder):
         print (file)
         preds = json.load(open(file))
         correct = 0
+        something = 0
         for k,v in preds.items():
+            options = v['options']
+            options = [process_raw_pred(e) for e in options]
+
+            #assert v['gt_name'] in options, f"{v['gt_name']} not in {options}"
+            if v['gt_name'] not in options:
+                print ('what?', options)
+                print ('what?', v)
+                break
+
             if v['gt_name'] == v['chatgpt_answer']:
                 correct+=1
+            else:
+                pass
+                #print ('wrong prediction! pred: gt', v['chatgpt_answer'] + "," + v['gt_name'])
         print ('acc ', correct/len(preds))
+        print ('gt not in options', something)
 
 
 
 if __name__ == '__main__':
-    # benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18')
-    # benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18')
-    # benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18')
-    # benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06')
-    # benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06')
-    # benchmark_random_mcq(-1, 'gpt-4o-2024-08-06')
-    calcuate_acc_from_jsons('gpt_full_benchmark_results')
+    # benchmark_avion_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    # benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    # benchmark_random_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    # benchmark_avion_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    # benchmark_tim_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    # benchmark_random_mcq(-1, 'gpt-4o-2024-08-06', 'GT_random_narration', benchmark_testing = True, n_frames = 8)
+    benchmark_tim_mcq(1, 'gpt-4o-mini-2024-07-18', 'official_key', benchmark_testing = False, n_frames = 16)
+    #benchmark_tim_mcq(-1, 'gpt-4o-mini-2024-07-18', 'GT_random_narration', benchmark_testing = False, n_frames = 16)
+    #calcuate_acc_from_jsons('gpt_EK100_results')

llava/action/chatgpt_utils.py

Lines changed: 3 additions & 0 deletions
@@ -483,6 +483,9 @@ def multi_process_run(self, offset= 0, n_samples = -1, disable_api_calling = Fal
         if combined_results and 'mc_' in self.question_type:
             calculation = calculate_gpt_accuracy(data = combined_results)
 
+        if n_samples == -1:
+            n_samples = len(self.data)
+
         checkpoint_name = f"{self.gpt_model}_{self.gen_type}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
 
         if self.do_visualization:
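
A small sketch of what the added guard changes (all attribute values below are hypothetical stand-ins, not taken from the repository): without it, a full run launched with n_samples = -1 would bake "-1samples" into the checkpoint filename; resolving n_samples against the dataset size first yields a meaningful name.

    # Hypothetical values for illustration only
    gpt_model, gen_type, action_representation = 'gpt-4o-mini-2024-07-18', 'mc', 'official_key'
    topk, clip_length, n_samples, dataset_size = 5, 16, -1, 9668

    if n_samples == -1:
        n_samples = dataset_size   # mirrors n_samples = len(self.data)

    checkpoint_name = f"{gpt_model}_{gen_type}_{action_representation}_top{topk}_{clip_length}f_{n_samples}samples.json"
    # -> 'gpt-4o-mini-2024-07-18_mc_official_key_top5_16f_9668samples.json'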

llava/action/make_visualizations.py

Lines changed: 38 additions & 4 deletions
@@ -150,9 +150,41 @@ def save_visualization(vis_folder, frames, uid):
     out_dir = Path(vis_folder)
     out_dir.mkdir(parents=True, exist_ok=True)
     sub_folder = out_dir / uid
+    fps = 30
+    height, width = frames[0].shape[:2]
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    video_path = str(sub_folder / f"{uid}.mp4")
+    video_out = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
     sub_folder.mkdir(parents=True, exist_ok=True)
     for idx, frame in enumerate(frames):
         cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        bgr_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+        video_out.write(bgr_frame)
+    video_out.release()
+
+def visualize_with_uid(uid):
+    from llava.action.utils import avion_video_loader
+    val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+    vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')
+    start_timestamp, end_timestamp = uid.split('_')[2:]
+    start_timestamp = float(start_timestamp)
+    end_timestamp = float(end_timestamp)
+    print (vid_path, start_timestamp, end_timestamp)
+    # split uid to video path and start, end second
+    frames, time_meta = avion_video_loader(root,
+                                           vid_path,
+                                           'MP4',
+                                           start_timestamp,
+                                           end_timestamp,
+                                           chunk_len = 15,
+                                           clip_length = n_frames,
+                                           threads = 1,
+                                           fast_rrc=False,
+                                           fast_rcc = False,
+                                           jitter = False)
+
+    vis_folder = f"figure1_vis"
+    save_visualization(vis_folder, frames, uid)
 
 def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
     """
@@ -216,7 +248,9 @@ def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
 
     #visualize_with_gpt_with_avion(10, offset = 100, question_type = "caption")
    #llava_pretrained_path = 'lmms-lab/LLaVA-Video-7B-Qwen2'
-    llava_pretrained_path = 'experiments/LLaVA-Video-7B-Qwen2'
-    uid = 'P01-P01_11_182.65_192.07'
-    visualize_with_llava(llava_pretrained_path, uid, 'caption', 'tim')
-
+    # llava_pretrained_path = 'experiments/LLaVA-Video-7B-Qwen2'
+    # uid = 'P01-P01_11_182.65_192.07'
+    # visualize_with_llava(llava_pretrained_path, uid, 'caption', 'tim')
+    visualize_with_uid("P28-P28_16_73.84_74.66")
+    visualize_with_uid("P28-P28_15_50.66_51.69")
+    visualize_with_uid("P26-P26_41_113.0_114.1")
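
For reference, this is how a uid such as the ones passed above is split by visualize_with_uid into a relative video path plus start and end seconds (a standalone rerun of the parsing lines, nothing new assumed):

    uid = "P28-P28_16_73.84_74.66"
    vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')                        # 'P28/P28_16'
    start_timestamp, end_timestamp = uid.split('_')[2:]
    start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)    # 73.84, 74.66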

llava/action/utils.py

Lines changed: 7 additions & 4 deletions
@@ -18,10 +18,12 @@
 from collections import defaultdict
 import json
 from llava.utils import rank0_print
-
+import re
 # set random seed
 random.seed(42)
 
+
+
 def remove_sub_nouns(nlp, narration, verb, nouns):
     narration = copy.deepcopy(narration)
     noun_list = ast.literal_eval(nouns)
@@ -433,9 +435,8 @@ def generate_multi_choice(self,
     def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = False):
         # letters as A, B, C, D, .. Note we maximally support 26 letters
         letters = [chr(65+i) for i in range(26)][:k]
-        answer_list = [vn for vn in mapping_vn2narration.keys()]
 
-
+        answer_list = [vn for vn in mapping_vn2narration.keys()]
         wrong_answers = np.random.choice(answer_list, size = k-1, replace = False)
         answer_ids = [gt_vn] + list(wrong_answers)
         random.shuffle(answer_ids)
@@ -456,7 +457,9 @@ def train_generate(self, gt_vn, narration, k, action_representation, n_narration
 
         gt_letter = letters[answer_ids.index(gt_vn)]
         gt_answer = answers[answer_ids.index(gt_vn)]
-
+        print ('got here')
+        import sys
+        sys.exit()
         mc_data = {
             'options': {0: options},
             # the correct letter in mc
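
A minimal, self-contained sketch of the option-assembly step that train_generate performs around the relocated answer_list line (the verb-noun keys and k below are invented; in the repository they come from mapping_vn2narration and the caller):

    import random
    import numpy as np

    k = 4
    gt_vn = 'take plate'                                       # hypothetical ground-truth key
    answer_list = ['take plate', 'open fridge', 'wash pan', 'cut onion', 'close drawer']

    letters = [chr(65 + i) for i in range(26)][:k]             # ['A', 'B', 'C', 'D']
    wrong_answers = np.random.choice(answer_list, size=k - 1, replace=False)
    answer_ids = [gt_vn] + list(wrong_answers)
    random.shuffle(answer_ids)
    gt_letter = letters[answer_ids.index(gt_vn)]               # letter that holds the ground truth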
