44from llava .action .ek_eval import prepare_llava
55from llava .action .generate_interval_pred import get_lookup_dict
66from llava .action .llava_inference import llava_inference
7+ from llava .action .utils import avion_video_loader
78
89from llava .constants import IMAGE_TOKEN_INDEX , DEFAULT_IMAGE_TOKEN
910# val_metadata = '/data/shaokai/epic-kitchens-100-annotations/EPIC_100_validation.csv'
2021
2122
2223def get_frames_by_uid (uid , root ):
23- from llava .action .utils import avion_video_loader
2424 vid_path = '_' .join (uid .split ('_' )[:2 ]).replace ('-' , '/' )
25+ print ('debug' , uid )
2526 start_timestamp , end_timestamp = uid .split ('_' )[2 :]
2627 start_timestamp = float (start_timestamp )
2728 end_timestamp = float (end_timestamp )
@@ -51,11 +52,11 @@ def get_meta_data():
5152 pass
5253
5354
54- def inference_task_by_uid (question , checkpoint_folder , uid , task ):
55+ def inference_task_by_uid (data_root , question , checkpoint_folder , uid , task ):
5556
5657 tokenizer , model , image_processor , max_length = prepare_llava (checkpoint_folder )
5758
58- frames , time_meta = get_frames_by_uid (uid , root )
59+ frames , time_meta = get_frames_by_uid (uid , data_root )
5960
6061 meta_data = None
6162 learn_neighbor_actions = ""
@@ -86,15 +87,56 @@ def inference_task_by_uid(question, checkpoint_folder, uid, task):
8687 perspective = perspective ,
8788 include_time_instruction = include_time_instruction
8889 )
89- print (pred )
90+ return pred
91+
92+ class SelectiveInferencer :
93+ def __init__ (self , data_root , checkpoint_folder , include_time_instruction = False , n_frames = 32 ):
94+ self .data_root = data_root
95+ self .checkpoint_folder = checkpoint_folder
96+ self .tokenizer , self .model , self .image_processor , self .max_length = prepare_llava (checkpoint_folder )
97+ self .include_time_instruction = include_time_instruction
98+ self .n_frames = n_frames
99+ def inference (self , question , uid , task ):
100+ frames , time_meta = get_frames_by_uid (uid , self .data_root )
101+
102+ meta_data = None
103+ learn_neighbor_actions = ""
104+ if 'temporal_cot' in task :
105+ lookup_table = get_lookup_dict (val_metadata ,
106+ action_representation ,
107+ test_type = task ,
108+ pseudo_folder = '' )
109+ meta_data = lookup_table .get (uid , None )
110+ learn_neighbor_actions = "prior"
111+
112+
113+ pred = llava_inference (
114+ [frames ],
115+ self .tokenizer ,
116+ self .model ,
117+ self .image_processor ,
118+ question ,
119+ test_type = task ,
120+ clip_length = self .n_frames ,
121+ num_frames = self .n_frames ,
122+ temperature = 0 ,
123+ time_meta = time_meta ,
124+ learn_neighbor_actions = learn_neighbor_actions ,
125+ meta_data = meta_data ,
126+ perspective = perspective ,
127+ include_time_instruction = self .include_time_instruction
128+ )
129+ return pred
130+
90131
91132if __name__ == '__main__' :
92133 pretrained_model_folder = 'experiments/dev_LLaVA-Video-7B-Qwen2_64f_top5_gpt4o_avion_tim_last_layer_one_token_detection_direct_neighbor_178K_100percent_time'
93134 uid = 'P28-P28_15_50.66_51.69'
94135 task = 'open-ended'
95136 question = "What is the object that is to the left of the knife?"
96137
97- inference_task_by_uid (question ,
138+ inference_task_by_uid (data_root ,
139+ question ,
98140 pretrained_model_folder ,
99141 uid ,
100142 task )
0 commit comments