Commit c52468c

Author: Ye Shaokai
Commit message: llava vis also works
Parent: c6a1a25

4 files changed: +129 -10 lines

llava/action/ek_eval.py
Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 import json
 import logging
 from llava.utils import rank0_print
-from llava.action.utils import generate_label_map, match_answer
+from llava.action.utils import generate_label_map
 from collections import Counter
 import torch.distributed as dist
 from llava.action.dataset import VideoMultiChoiceDataset

llava/action/llava_inference.py
Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def llava_inference(
     elif test_type == "direct_narration":
         question_type = "direct_narration"
     elif test_type == 'caption' or test_type == 'debug':
-        question_type = "gpt-gt-reason"
+        question_type = "caption"
     elif test_type == 'temporal_cot':
         question_type = 'temporal_cot'

llava/action/make_visualizations.py
Lines changed: 126 additions & 7 deletions

@@ -9,9 +9,13 @@
 
 Note that in each inference, we should be able to pick the corresponding prompt and checkpoint folder
 """
-
+from llava.action.utils import generate_label_map
 from llava.action.chatgpt_utils import GPTInferenceAnnotator
-
+from pathlib import Path
+from llava.action.utils import AvionMultiChoiceGenerator as ActionMultiChoiceGenerator
+from llava.action.llava_inference import llava_inference
+import json
+import cv2
 # root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
 # annotation_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
 # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
@@ -31,7 +35,6 @@
 benchmark_testing = True
 
 
-
 def visualize_with_random(n_samples, offset = 0, question_type = 'mc_'):
     """
     Here we should test gpt-4o, gpt-4o-mini with different prompts
@@ -75,6 +78,7 @@ def visualize_with_gpt_with_avion(n_samples, offset = 0, question_type = 'mc_'):
     """
     Here we should test gpt-4o, gpt-4o-mini with different prompts
     """
+
     inferencer = GPTInferenceAnnotator(gpt_model,
                                        root,
                                        annotation_file,
@@ -91,13 +95,128 @@ def visualize_with_gpt_with_avion(n_samples, offset = 0, question_type = 'mc_'):
     inferencer.multi_process_run(n_samples = n_samples, offset = offset, disable_api_calling=False)
 
 
-def visualize_with_llava(uid, ):
-    """
+def search_option_data_by_uid(uid, anno_file, gen_type = 'tim'):
+    import csv
+    from llava.action.dataset import datetime2sec
+    csv_reader = csv.reader(open(anno_file, 'r'))
+    _ = next(csv_reader)  # skip the header
+    query_vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')
+    query_start_timestamp, query_end_timestamp = uid.split('_')[2:]
+    anno_root = Path(anno_file).parent
+    labels, mapping_vn2narration, mapping_vn2act, verb_maps, noun_maps = generate_label_map(anno_root,
+                                                                                            action_representation)
+    with open(tim_prediction_file, 'r') as f:
+        action_model_predictions = json.load(f)
+    mc_generator = ActionMultiChoiceGenerator(anno_root)
 
-    """
+    for idx, row in enumerate(csv_reader):
+        pid, vid = row[1:3]
+        start_second, end_second = datetime2sec(row[4]), datetime2sec(row[5])
+        start_second = round(float(start_second), 2)
+        end_second = round(float(end_second), 2)
+        vid_path = '{}/{}'.format(pid, vid)
+        verb, noun = int(row[10]), int(row[12])
+        gt_vn = '{}:{}'.format(verb, noun)
+        narration = row[8]
 
+        if query_vid_path != vid_path and start_second != query_start_timestamp and end_second != query_end_timestamp:
+            continue
+
+        if gen_type == 'avion' or gen_type == 'tim':
+            action_preds = action_model_predictions[str(idx)]['predictions']
+            mc_data = mc_generator.generate_multi_choice(gt_vn,
+                                                         action_preds,
+                                                         narration,
+                                                         topk,
+                                                         action_representation,
+                                                         -1,  # n_narrations
+                                                         labels,
+                                                         mapping_vn2narration,
+                                                         verb_maps,
+                                                         noun_maps,
+                                                         benchmark_testing = benchmark_testing,
+                                                         is_train = False)
+
+        options = mc_data['options'][0]
+        return {
+            'options': options,
+            'narration': narration,
+            'start_second': start_second,
+            'end_second': end_second,
+            'gt_answer': gt_vn
+        }
+
+def save_visualization(vis_folder, frames, uid):
+    out_dir = Path(vis_folder)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    sub_folder = out_dir / uid
+    sub_folder.mkdir(parents=True, exist_ok=True)
+    for idx, frame in enumerate(frames):
+        cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 
+def visualize_with_llava(pretrained_path, uid, question_type, gen_type):
+    """
+    """
+    from llava.action.ek_eval import prepare_llava
+    from llava.action.dataset import VideoMultiChoiceDataset
+
+    import torch
+
+    from llava.action.utils import avion_video_loader
+    val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+
+    gpu_val_transform_ls = []
 
+    val_transform_gpu = torch.nn.Sequential(*gpu_val_transform_ls)
+
+    vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')
+    start_timestamp, end_timestamp = uid.split('_')[2:]
+    start_timestamp = float(start_timestamp)
+    end_timestamp = float(end_timestamp)
+    print(vid_path, start_timestamp, end_timestamp)
+    # split uid to video path and start, end second
+    frames, time_meta = avion_video_loader(root,
+                                           vid_path,
+                                           'MP4',
+                                           start_timestamp,
+                                           end_timestamp,
+                                           chunk_len = 15,
+                                           clip_length = n_frames,
+                                           threads = 1,
+                                           fast_rrc = False,
+                                           fast_rcc = False,
+                                           jitter = False)
+
+    vis_folder = f"{gpt_model}_{gen_type}_{question_type}_{perspective}"
+    save_visualization(vis_folder, frames, uid)
+
+    options = search_option_data_by_uid(uid, val_metadata, gen_type = gen_type)
+
+    print(options)
+    mc_data = options
+    tokenizer, model, image_processor, _ = prepare_llava(pretrained_path)
+    pred = llava_inference(
+        [frames],
+        tokenizer,
+        model,
+        image_processor,
+        mc_data,
+        test_type = question_type,
+        clip_length = n_frames,
+        num_frames = n_frames,
+        temperature = 0,
+        time_meta = time_meta,
+        learn_neighbor_actions = False,
+        meta_data = None,
+        perspective = perspective
+    )
+
+    print(pred)
 if __name__ == '__main__':
 
-    visualize_with_gpt_with_avion(10, offset = 100, question_type = "caption")
+    # visualize_with_gpt_with_avion(10, offset = 100, question_type = "caption")
+    llava_pretrained_path = 'lmms-lab/LLaVA-Video-7B-Qwen2'
+    llava_pretrained_path = 'experiments/LLaVA-Video-7B-Qwen2'
+    uid = 'P01-P01_11_34.38_41.15'
+    visualize_with_llava(llava_pretrained_path, uid, 'caption', 'tim')
+
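Aside (not part of the commit): both new functions decode a segment uid of the form {pid}-{vid}_{start}_{end}. Note that, as committed, the matching condition in search_option_data_by_uid compares the CSV's float seconds against the uid's string timestamps and joins the three tests with `and`, so in practice it returns the first annotation row whose video path matches, regardless of timestamps. Below is a minimal sketch of the uid convention plus a stricter predicate; parse_uid, row_matches, and the tolerance are illustrative names and choices, not from the repo:

    def parse_uid(uid):
        # 'P01-P01_11_34.38_41.15' -> ('P01/P01_11', 34.38, 41.15)
        vid_path = '_'.join(uid.split('_')[:2]).replace('-', '/')
        start, end = map(float, uid.split('_')[2:])
        return vid_path, start, end

    def row_matches(query, vid_path, start_second, end_second, tol=0.05):
        # Keep a row only if video path AND both timestamps agree (within tol);
        # the tolerance covers the round(..., 2) applied to the CSV seconds.
        q_path, q_start, q_end = query
        return (q_path == vid_path
                and abs(q_start - start_second) <= tol
                and abs(q_end - end_second) <= tol)

    assert parse_uid('P01-P01_11_34.38_41.15') == ('P01/P01_11', 34.38, 41.15)

save_visualization works as committed even though the frames come back in RGB: cv2.COLOR_BGR2RGB and cv2.COLOR_RGB2BGR perform the same channel swap, so cv2.imwrite receives BGR data either way.
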
llava/action/utils.py
Lines changed: 1 addition & 1 deletion

@@ -651,7 +651,7 @@ def avion_video_loader(root, vid, ext, second, end_second,
     chunk_start = int(second) // chunk_len * chunk_len
     chunk_end = int(end_second) // chunk_len * chunk_len
     while True:
-        video_filename = osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk_end, ext))
+        video_filename = osp.join(root, '{}.{}'.format(vid, ext), '{}.{}'.format(chunk_end, ext))
         if not osp.exists(video_filename):
             # print("{} does not exists!".format(video_filename))
             chunk_end -= chunk_len
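The removed and added lines in this hunk render identically, so the change appears to be whitespace-only. For context, a sketch of the chunk arithmetic this loop relies on, assuming (consistent with the 15sec root path and the video_filename join above) that each video is stored as fixed-length chunk files named by their start second:

    def chunk_bounds(second, end_second, chunk_len=15):
        # Each chunk file covers [k * chunk_len, (k + 1) * chunk_len) seconds
        # and is named '<k * chunk_len>.MP4', e.g. P01/P01_11.MP4/30.MP4.
        chunk_start = int(second) // chunk_len * chunk_len
        chunk_end = int(end_second) // chunk_len * chunk_len
        return chunk_start, chunk_end

    # For the uid above (P01-P01_11_34.38_41.15) the whole clip sits in one chunk:
    assert chunk_bounds(34.38, 41.15) == (30, 30)

The while True loop then walks chunk_end backwards until a chunk file actually exists, which appears to handle the final, possibly shorter, chunk of each video.
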
