
Commit 8891c50

Author: Ye Shaokai
Merge branch 'visualization' into shaokai/dev
2 parents: 46ea7d6 + 126392b

File tree

6 files changed: +287 / -33 lines


llava/action/benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ def benchmark_tim_mcq(n_samples):
         root,
         annotation_file,
         gen_type = 'tim',
-        prediction_file = avion_prediction_file,
+        prediction_file = tim_prediction_file,
         clip_length = n_frames,
         question_type = 'mc_',
         action_representation=action_representation,
@@ -63,6 +63,6 @@ def benchmark_random_mcq(n_samples):
 
 
 if __name__ == '__main__':
-    benchmark_avion_mcq(100)
+    #benchmark_avion_mcq(100)
     benchmark_tim_mcq(100)
     #benchmark_random_mcq(100)
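Side note on the fix above: benchmark_tim_mcq had been wired to the AVION prediction file. A minimal sketch, not code from this repo, of a registry that makes such a mismatch fail loudly; the paths are hypothetical placeholders:

PREDICTION_FILES = {
    'avion': '/data/avion_predictions.json',  # hypothetical placeholder path
    'tim': '/data/tim_predictions.json',      # hypothetical placeholder path
}

def prediction_file_for(gen_type):
    # raise instead of silently reusing the wrong model's predictions
    if gen_type not in PREDICTION_FILES:
        raise ValueError(f"no prediction file registered for gen_type={gen_type!r}")
    return PREDICTION_FILES[gen_type]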

llava/action/chatgpt_utils.py

Lines changed: 53 additions & 21 deletions
@@ -19,6 +19,7 @@
 import base64
 from pathlib import Path
 import traceback
+import cv2
 
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -348,7 +349,8 @@ def __init__(self,
                 debug = False,
                 topk = 10,
                 perspective = 'first_person',
-                benchmark_testing = False
+                benchmark_testing = False,
+                do_visualization = False
                 ):
        """
        Parameters
@@ -373,17 +375,31 @@ def __init__(self,
         self.perspective = perspective
         self.benchmark_testing = benchmark_testing
         assert gen_type in ['avion', 'tim', 'random']
-
+
         if gen_type == 'avion' or gen_type == 'tim':
             self.mc_generator = ActionMultiChoiceGenerator(self.annotation_root)
+            assert os.path.exists(self.prediction_file)
             with open(self.prediction_file, 'r') as f:
                 self.action_model_predictions = json.load(f)
         else:
             self.mc_generator = RandomMultiChoiceGenerator(self.annotation_root)
 
-
+        self.do_visualization = do_visualization
+        self.vis_folder = f"{self.gpt_model}_{self.gen_type}_{self.question_type}_{self.perspective}"
+        os.makedirs(self.vis_folder, exist_ok = True)
         self.data = self.init_data()
-
+
+    def save_visualization(self, frames, uid):
+        """
+        Save the frames to the out_dir
+        """
+        out_dir = Path(self.vis_folder)
+        out_dir.mkdir(parents=True, exist_ok=True)
+        sub_folder = out_dir / uid
+        sub_folder.mkdir(parents=True, exist_ok=True)
+        for idx, frame in enumerate(frames):
+            cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"), cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
 
     def init_data(self):
         ret = {}
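The new save_visualization writes one JPEG per frame into <vis_folder>/<uid>/. Note that cv2.imwrite expects BGR input, and cv2.COLOR_BGR2RGB performs the same channel swap as cv2.COLOR_RGB2BGR, so the conversion is correct provided the decoded frames are RGB. A self-contained sketch of the same pattern, with a dummy frame standing in for real video data:

import cv2
import numpy as np
from pathlib import Path

def save_frames(frames, uid, out_root="vis"):
    # dump assumed-RGB frames as JPEGs under out_root/uid/
    sub_folder = Path(out_root) / uid
    sub_folder.mkdir(parents=True, exist_ok=True)
    for idx, frame in enumerate(frames):
        # imwrite expects BGR, so swap channels from the assumed RGB input
        cv2.imwrite(str(sub_folder / f"{uid}_{idx}.jpg"),
                    cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

# usage: three dummy 224x224 RGB frames under vis/P01-P01_101.MP4_12.3_15.7/
save_frames([np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(3)],
            "P01-P01_101.MP4_12.3_15.7")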
@@ -435,41 +451,45 @@ def init_data(self):
                 'end_second': end_second,
                 'vid_path': vid_path
             }
-
         return ret
 
-    def multi_process_run(self, n_samples = -1):
-        # to initialize it
+    def multi_process_run(self, offset = 0, n_samples = -1, disable_api_calling = False):
+        # inside GPT inference annotator
 
-        if n_samples != -1:
-            indices = list(range(len(self.data)))[:n_samples]
+        if n_samples == -1:
+            # do not use offset if n_samples is -1
+            assert offset == 0
 
+        if n_samples != -1:
+            indices = list(range(len(self.data)))[offset:offset + n_samples]
         num_chunks = os.cpu_count() if not self.debug else 2
 
         indices_groups = self.split_indices(indices, num_chunks)
 
         with ProcessPoolExecutor(max_workers=num_chunks) as executor:
             # Pass additional arguments to the function
-            futures = [executor.submit(self.run, group) for group in indices_groups]
+            futures = [executor.submit(self.run, group, disable_api_calling) for group in indices_groups]
 
             # Wait for all futures to complete
             combined_results = {}
             for future in futures:
                 result_dict = future.result()
                 combined_results.update(result_dict)
-
+        print (combined_results)
         if self.debug:
             print (combined_results)
-
-        calculation = calculate_gpt_accuracy(data = combined_results)
+        if combined_results and 'mc_' in self.question_type:
+            calculation = calculate_gpt_accuracy(data = combined_results)
 
         prefix = self.gen_type
         assert n_samples != -1
         checkpoint_name = f"{prefix}_{self.action_representation}_top{self.topk}_{self.clip_length}f_{n_samples}samples.json"
 
+        if self.do_visualization:
+            self.checkpoint(combined_results, os.path.join(self.vis_folder, checkpoint_name))
         self.checkpoint(combined_results, checkpoint_name)
 
-    def run(self, indices=None):
+    def run(self, indices=None, disable_api_calling = False):
         if indices is None:
             data_batch = {i : self.data[i] for i in range(len(self.data)) if i in list(range(len(self.data)))}
         else:
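multi_process_run now takes a window (offset, n_samples) into the dataset and fans the selected indices out over a process pool; note that indices is only assigned on the n_samples != -1 path, which the later assert n_samples != -1 makes explicit. split_indices is not part of this diff, so the contiguous near-equal chunking below is an assumption; the rest is a runnable sketch of the same fan-out pattern:

import os
from concurrent.futures import ProcessPoolExecutor

def split_indices(indices, num_chunks):
    # assumed behavior: contiguous, near-equal chunks
    k, m = divmod(len(indices), num_chunks)
    return [indices[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]
            for i in range(num_chunks)]

def run_group(group, disable_api_calling=False):
    # stand-in for the annotator's run(): returns {index: result}
    return {i: None if disable_api_calling else f"answer-{i}" for i in group}

if __name__ == '__main__':
    data_len, offset, n_samples = 1000, 200, 100
    indices = list(range(data_len))[offset:offset + n_samples]
    groups = split_indices(indices, os.cpu_count())
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as executor:
        futures = [executor.submit(run_group, g, True) for g in groups]
        combined_results = {}
        for future in futures:
            combined_results.update(future.result())
    print(len(combined_results))  # 100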
@@ -481,22 +501,36 @@ def run(self, indices=None):
             start_timestamp = v['start_second']
             end_timestamp = v['end_second']
             vid_path = v['vid_path']
+            _id = v['vid_path'].replace('/', '-')
+            uid = f"{_id}_{start_timestamp}_{end_timestamp}"
 
             frames, time_meta = self.extract_frames(vid_path, start_timestamp, end_timestamp)
-            try:
+
+            if self.do_visualization:
+                # the output folder should reflect the gen type, question type and perspective
+                # and the question type
+                self.save_visualization(frames, uid)
+            if disable_api_calling:
+                break
+            try:
                 parsed_answer = self.predict_images(frames, v)
             except Exception as e:
                 # get full stack trace
-                traceback.print_exc()
-
+                traceback.print_exc()
                 print ("An exception occurred: ", e)
 
             predicted_answer = parsed_answer.answer
             gt_name = v['gt_answer']
             ret[k] = {
+                "uid": uid,
                 'gt_name': gt_name,
-                'chatgpt_answer': process_raw_pred(predicted_answer),
+                "options": v['options'],
+                'chatgpt_answer': process_raw_pred(predicted_answer) if 'mc_' in self.question_type else predicted_answer
             }
+            if self.do_visualization:
+                # save ret to the output folder
+                self.checkpoint(ret, os.path.join(self.vis_folder, uid, 'inference_results.json'))
+
             if self.debug:
                 break
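The uid introduced here names both the visualization subfolder and the per-sample results entry. A worked example of the scheme, with a made-up clip path:

vid_path = "P01/P01_101.MP4"   # hypothetical clip path
start_second, end_second = 12.3, 15.7
_id = vid_path.replace('/', '-')
uid = f"{_id}_{start_second}_{end_second}"
print(uid)  # P01-P01_101.MP4_12.3_15.7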

@@ -529,9 +563,7 @@ def predict_images(self, images, parsed_item):
 
         if 'o1' in self.gpt_model:
             system_prompt += format_prompt
-
-        #print (system_prompt)
-
+
 
         if self.handobj_root is not None:
             system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'."""

llava/action/ek_eval.py

Lines changed: 2 additions & 2 deletions
@@ -11,10 +11,10 @@
 import json
 import logging
 from llava.utils import rank0_print
-from llava.action.utils import generate_label_map, match_answer
+from llava.action.utils import generate_label_map
 from collections import Counter
 import torch.distributed as dist
-from llava.action.dataset import VideoMultiChoiceDataset, VideoTemporalMultiChoiceDataset
+from llava.action.dataset import VideoMultiChoiceDataset
 import torchvision.io as io
 import re
 
llava/action/llava_inference.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def llava_inference(
         elif test_type == "direct_narration":
             question_type = "direct_narration"
         elif test_type == 'caption' or test_type == 'debug':
-            question_type = "gpt-gt-reason"
+            question_type = "caption"
         elif test_type == 'temporal_cot':
             question_type = 'temporal_cot'
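This one-line fix routes the 'caption' and 'debug' test types to the "caption" question type instead of "gpt-gt-reason". The same dispatch written as a lookup table, covering only the branches visible in this hunk:

# equivalent table for the branches shown above (other test types omitted)
QUESTION_TYPE_BY_TEST = {
    "direct_narration": "direct_narration",
    "caption": "caption",
    "debug": "caption",   # previously "gpt-gt-reason"
    "temporal_cot": "temporal_cot",
}
question_type = QUESTION_TYPE_BY_TEST["debug"]  # -> "caption"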
