
Commit e9b81f2

Author: Haozhe Qi
Commit message: better inference behavior
Parent: 5e791f9

3 files changed: 80 additions, 42 deletions


llava/action/llava_inference.py

Lines changed: 12 additions & 17 deletions

@@ -14,7 +14,7 @@ def llava_inference(
                     tokenizer,
                     model,
                     image_processor,
-                    mc_data,
+                    input,
                     clip_length = 16,
                     num_frames = 16,
                     temperature = 0,
@@ -44,35 +44,29 @@ def llava_inference(
         image_tensors.append(frames)
 
     conv_template = "qwen_1_5"
-
-    options = mc_data['options'][0]
+    original_input = input
+    if isinstance(input, dict):
+        input = input['options'][0] if input else None
+
     if test_type == 'base':
         question_type = "mc_top5_official_key"
-    elif test_type == "direct_narration":
-        question_type = "direct_narration"
-    elif test_type == 'caption' or test_type == 'debug':
-        question_type = "caption"
-    elif test_type == 'temporal_cot_pseudo':
-        question_type = 'temporal_cot_pseudo'
-    elif test_type == 'temporal_cot_oracle':
-        question_type = 'temporal_cot_oracle'
-    elif test_type == 'temporal_cot_caption':
-        question_type = 'temporal_cot_caption'
+    else:
+        question_type = test_type
 
     if test_type == 'caption_then_answer':
         caption_answer = llava_inference([video_frames],
                                          tokenizer,
                                          model,
                                          image_processor,
-                                         mc_data,
+                                         original_input,
                                          test_type = 'caption',
                                          clip_length = clip_length,
                                          num_frames = num_frames,
                                          temperature = 0,
                                          time_meta = time_meta)
 
         question = format_llava_prompt(DEFAULT_IMAGE_TOKEN,
-                                       options,
+                                       input,
                                        video_duration,
                                        n_frames,
                                        "mc_top5_official_key",
@@ -85,7 +79,7 @@ def llava_inference(
 
     else:
         question = format_llava_prompt(DEFAULT_IMAGE_TOKEN,
-                                       options,
+                                       input,
                                        video_duration,
                                        n_frames,
                                        question_type,
@@ -102,7 +96,8 @@ def llava_inference(
     conv.append_message(conv.roles[0], question)
     conv.append_message(conv.roles[1], None)
     prompt_question = conv.get_prompt()
-
+    print ("what is the question?", question)
+
 
     input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
     image_sizes = [frame.size for frame in video_frames]
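
After this change, llava_inference accepts either the multiple-choice dict it previously required or a plain question string. A minimal sketch of the two call shapes, assuming tokenizer, model, image_processor, frames, and time_meta already exist from the surrounding pipeline (the option strings and the question text are illustrative, not from the repo):

    # Multiple-choice path: a dict whose 'options' field holds candidate lists;
    # the function unwraps it via input['options'][0].
    mc_data = {'options': [['open drawer', 'close drawer', 'pick up knife']]}
    pred = llava_inference([frames], tokenizer, model, image_processor, mc_data,
                           test_type='base', clip_length=16, num_frames=16,
                           temperature=0, time_meta=time_meta)

    # Open-ended path: a plain string fails the isinstance(input, dict) check and
    # reaches format_llava_prompt unchanged; question_type falls through to test_type.
    question = 'What am I holding in my right hand?'
    pred = llava_inference([frames], tokenizer, model, image_processor, question,
                           test_type='open-ended', clip_length=16, num_frames=16,
                           temperature=0, time_meta=time_meta)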

llava/action/selective_inference.py

Lines changed: 65 additions & 25 deletions

@@ -3,12 +3,21 @@
 """
 from llava.action.ek_eval import prepare_llava
 from llava.action.generate_interval_pred import get_lookup_dict
-from llava.action.inference import llava_inference
-
-val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
-root = '/data/shaokai/EK100_512/EK100'
+from llava.action.llava_inference import llava_inference
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+# val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+# root = '/data/shaokai/EK100_512/EK100'
+val_metadata = '/iopsstor/scratch/cscs/hqi/VFM/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+root = '/iopsstor/scratch/cscs/hqi/VFM/onevision/EK100_512/EK100'
+
 n_frames = 32
 action_representation = 'GT_random_narration'
+perspective = 'first_person'
+include_time_instruction = False
+image_token = DEFAULT_IMAGE_TOKEN
+
+
 
 def get_frames_by_uid(uid, root):
     from llava.action.utils import avion_video_loader
@@ -29,32 +38,63 @@ def get_frames_by_uid(uid, root):
                                fast_rrc=False,
                                fast_rcc = False,
                                jitter = False)
-    return frames
+    return frames, time_meta
+#
+
+
+
+
 
-def inference_task_by_uid(checkpoint_folder, uid, task):
+# for prior actions
+def get_meta_data():
+    pass
+
+
+def inference_task_by_uid(question, checkpoint_folder, uid, task):
 
     tokenizer, model, image_processor, max_length = prepare_llava(checkpoint_folder)
 
-    frames = get_frames_by_uid(uid, root)
-
+    frames, time_meta = get_frames_by_uid(uid, root)
+
+    meta_data = None
+    learn_neighbor_actions = ""
     if 'temporal_cot' in task:
-        get_lookup_dict(val_metadata,
+        lookup_table = get_lookup_dict(val_metadata,
                         action_representation,
                         test_type = task,
                         pseudo_folder = '')
-    pred = llava_inference(
-        frames,
-        tokenizer,
-        model,
-        image_processor,
-        mc_data,
-        test_type = test_type,
-        clip_length = clip_length,
-        num_frames=num_frames,
-        temperature = temperature,
-        time_meta = time_meta,
-        learn_neighbor_actions = learn_neighbor_actions,
-        meta_data = meta_data,
-        perspective = perspective,
-        include_time_instruction = include_time_instruction
-    )
+        meta_data = lookup_table.get(uid, None)
+        learn_neighbor_actions = "prior"
+
+    video_duration = time_meta['duration']
+
+
+    pred = llava_inference(
+        [frames],
+        tokenizer,
+        model,
+        image_processor,
+        question,
+        test_type = task,
+        clip_length = n_frames,
+        num_frames= n_frames,
+        temperature = 0,
+        time_meta = time_meta,
+        learn_neighbor_actions = learn_neighbor_actions,
+        meta_data = meta_data,
+        perspective = perspective,
+        include_time_instruction = include_time_instruction
+    )
+    print (pred)
+
+if __name__ == '__main__':
+    pretrained_model_folder = 'experiments/dev_LLaVA-Video-7B-Qwen2_64f_top5_gpt4o_avion_tim_last_layer_one_token_detection_direct_neighbor_178K_100percent_time'
+    uid = 'P28-P28_15_50.66_51.69'
+    task = 'open-ended'
+    question = "What is the object that is to the left of the knife?"
+
+    inference_task_by_uid(question,
                          pretrained_model_folder,
                          uid,
                          task)
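
The new __main__ block doubles as a usage example. A sketch of calling the entry point for a different open-ended question, assuming a trained checkpoint folder and the EK100 paths configured at the top of the module (the checkpoint name below is hypothetical):

    from llava.action.selective_inference import inference_task_by_uid

    # get_frames_by_uid now returns (frames, time_meta), so the function can read
    # the clip duration internally; the caller only supplies these four arguments.
    inference_task_by_uid('Which hand reaches the tap first?',
                          'experiments/my_checkpoint',  # hypothetical checkpoint folder
                          'P28-P28_15_50.66_51.69',
                          'open-ended')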

llava/action/utils.py

Lines changed: 3 additions & 0 deletions

@@ -291,6 +291,9 @@ def format_task_related_prompt(question, question_type, meta_data = None, perspe
     elif question_type == "dpo":
         ret = "You are seeing this video from egocentric view and you are the person. Your hands are sometimes interacting with obects. Describe in details what you see and what you are doing."
 
+    elif question_type == "open-ended":
+        ret = f"You are seeing this video from egocentric view and you are the person. {question}"
+
     elif question_type == "gpt-gt-instruct-reason":
         ret = question
     elif question_type == "gpt-hand-object":
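
This branch is what makes task = 'open-ended' in selective_inference.py resolve to a valid prompt. A trivially runnable check of what the template expands to for the question used there:

    question = "What is the object that is to the left of the knife?"
    # Mirrors the new f-string branch above; prints the full prompt on one line.
    ret = f"You are seeing this video from egocentric view and you are the person. {question}"
    print(ret)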
