@@ -85,10 +85,13 @@ def video_loader(root, vid, ext, second, end_second,
         except decord.DECORDError as error:
             print(error)
             frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
-
+
         return torch.from_numpy(frames.astype(np.float32))

     else:
+        time_meta = {}
+
+        time_meta['duration'] = end_second - second
         chunk_start = int(second) // chunk_len * chunk_len
         chunk_end = int(end_second) // chunk_len * chunk_len
         while True:
@@ -109,6 +112,7 @@ def video_loader(root, vid, ext, second, end_second,
             num_segments=clip_length, jitter=jitter
         )
         all_frames = []
+        all_frame_ids = []
         # split the absolute frame ids into chunk-relative ones
         for chunk in range(chunk_start, chunk_end + chunk_len, chunk_len):
             rel_frame_ids = list(filter(lambda x: int(chunk * fps) <= x < int((chunk + chunk_len) * fps), frame_ids))
@@ -127,11 +131,17 @@ def video_loader(root, vid, ext, second, end_second,
             except IndexError:
                 print(root, vid, ext, second, end_second)
             all_frames.append(frames)
+            all_frame_ids.append(frame_ids)
             if sum(map(lambda x: x.shape[0], all_frames)) == clip_length:
                 break
         res = torch.from_numpy(np.concatenate(all_frames, axis=0).astype(np.float32))
+        time_meta['n_frames'] = res.shape[0]
+        all_frame_ids = np.concatenate(all_frame_ids, axis=0)
+        frame_time = [e / fps for e in all_frame_ids]
+        frame_time = [f"{i:.2f}s" for i in frame_time]
+        time_meta['frame_time'] = frame_time
         assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
-        return res
+        return res, time_meta


 class VideoCaptionDatasetBase(torch.utils.data.Dataset):
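
For reference, a minimal sketch of the time_meta contract that video_loader now returns alongside the frames. The field names come straight from the diff; the helper name and the fps/frame-id values in the usage line are illustrative only:

    import numpy as np

    def build_time_meta(second, end_second, frame_ids, fps):
        # mirrors the bookkeeping added to video_loader above
        time_meta = {}
        time_meta['duration'] = end_second - second                      # clip span in seconds
        time_meta['n_frames'] = len(frame_ids)                           # frames actually decoded
        # absolute frame index -> timestamp string such as '4.00s'
        time_meta['frame_time'] = [f"{i / fps:.2f}s" for i in frame_ids]
        return time_meta

    # build_time_meta(4.0, 6.0, np.arange(120, 180, 10), 30.0)
    # -> {'duration': 2.0, 'n_frames': 6, 'frame_time': ['4.00s', '4.33s', ...]}
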
@@ -194,53 +204,11 @@ def get_raw_item(
         fast_rrc=False, rrc_params=(224, (0.5, 1.0)),
         fast_rcc=False, rcc_params=(224,),
     ):
-        if self.dataset == 'ego4d':
-            vid, start_second, end_second, narration = self.samples[i][:4]
-            frames = video_loader(self.root, vid, 'mp4',
-                                  start_second, end_second,
-                                  chunk_len=chunk_len,
-                                  clip_length=clip_length,
-                                  threads=threads,
-                                  fast_rrc=fast_rrc,
-                                  rrc_params=rrc_params,
-                                  fast_rcc=fast_rcc,
-                                  rcc_params=rcc_params,
-                                  jitter=is_training)
-            if isinstance(narration, list):
-                if narration_selection == 'random':
-                    narration = random.choice(narration)
-                elif narration_selection == 'concat':
-                    narration = '. '.join(narration)
-                elif narration_selection == 'list':
-                    pass
-                else:
-                    raise ValueError
-            return frames, narration
-        elif self.dataset == 'ek100_mir':
-            vid_path, start_second, end_second, fps, narration, verb, noun = self.samples[i]
-            frames = video_loader(self.root, vid_path, 'MP4',
-                                  start_second, end_second,
-                                  chunk_len=chunk_len, fps=fps,
-                                  clip_length=clip_length,
-                                  threads=threads,
-                                  fast_rrc=fast_rrc,
-                                  rrc_params=rrc_params,
-                                  fast_rcc=fast_rcc,
-                                  rcc_params=rcc_params,
-                                  jitter=is_training)
-            if is_training:
-                positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
-                if positive_list != []:
-                    pos = random.sample(positive_list, min(len(positive_list), 1))[0]
-                    if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
-                        return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
-            else:
-                return frames, (narration, 1)
-        elif self.dataset == 'ek100_cls':
+        if self.dataset == 'ek100_cls':
             vid_path, start_second, end_second, fps, narration, verb, noun = self.samples[i]
             # chunk_len is the length (in seconds) of each chunked video segment
             # clip_length is the number of frames we sample from the clip
-            frames = video_loader(self.root, vid_path, 'MP4',
+            frames, time_meta = video_loader(self.root, vid_path, 'MP4',
                                   start_second, end_second,
                                   chunk_len=chunk_len, fps=fps,
                                   clip_length=clip_length,
@@ -250,7 +218,7 @@ def get_raw_item(
                                   fast_rcc=fast_rcc,
                                   rcc_params=rcc_params,
                                   jitter=is_training)
-            return frames, '{}:{}'.format(verb, noun)
+            return frames, '{}:{}'.format(verb, noun), time_meta
         else:
             raise NotImplementedError

@@ -303,7 +271,7 @@ def __init__(
             self.mc_generator = MultiChoiceGenerator(self.ann_root)

     def __getitem__(self, i):
-        frames, label = self.get_raw_item(
+        frames, label, time_meta = self.get_raw_item(
             i, is_training=self.is_training,
             chunk_len=self.chunk_len,
             num_clips=self.num_clips,
@@ -317,13 +285,15 @@ def __getitem__(self, i):
             sparse_sample=self.sparse_sample,
         )

+        # for llava-video to work, we also need the time metadata.
+
         # apply transformation
         if self.transform is not None:
             frames = self.transform(frames)

         data = self.mc_generator.generate_multi_choice(label, self.topk_predictions)

-        return frames, data
+        return frames, data, time_meta



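Downstream code can then unpack the dataset's 3-tuple directly. A hypothetical usage sketch (the dataset variable name and index are illustrative; with a batched DataLoader the time_meta values come back collated rather than as plain Python scalars):

    frames, data, time_meta = dataset[0]
    print(time_meta['duration'], time_meta['n_frames'])
    print(time_meta['frame_time'][:3])   # e.g. ['4.00s', '4.13s', '4.27s']
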
@@ -350,7 +320,7 @@ def get_args_parser():

     # llava related
     # llm size is a string such as '7b' or '0.5b'
-    parser.add_argument('--llm_size', default='7b', type=str, help='llm size')
+    parser.add_argument('--pretrained_name', default='', type=str, help='the model name on HuggingFace')
     parser.add_argument('--llava_num_frames', default=16, type=int, help='number of frames for llava')
     ## avion refinement
     parser.add_argument('--action_predictions', default=None, type=str, help='path to action predictions')
@@ -362,13 +332,15 @@ def get_args_parser():
 def prepare_llava(pretrained):

     import warnings
-    from llava.model.builder import load_pretrained_model
     warnings.filterwarnings("ignore")
-    # Load the OneVision model
+    from llava.model.builder import load_pretrained_model
     model_name = "llava_qwen"

     device_map = "auto"
+    print('pretrained???', pretrained)
+    # tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
+    tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)  # pass any other llava_model_args here
+

     return tokenizer, model, image_processor, max_length

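A usage sketch for prepare_llava, assuming the lmms-lab OneVision weights that the old hard-coded path below pointed at; any HuggingFace repo id or local checkpoint path should work the same way:

    tokenizer, model, image_processor, max_length = prepare_llava(
        "lmms-lab/llava-onevision-qwen2-7b-ov")
    model.eval()
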
@@ -392,7 +364,9 @@ def get_topk_predictions(data, idx, k):

     return mc_data

-def ensemble_llava_evaluation(gt_name,
+def ensemble_llava_evaluation(
+        pretrained_name,
+        gt_name,
         frames,
         tokenizer,
         model,
@@ -402,6 +376,7 @@ def ensemble_llava_evaluation(gt_name,
         num_frames,
         temperature=0,
         ensemble_k=1,
+        time_meta=None,
         is_test=False
 ):
     """
@@ -424,20 +399,24 @@ def ensemble_llava_evaluation(gt_name,
         rank0_print('generated new option sequence')
         rank0_print(options)

-        pred = llava_inference(frames,
-                               tokenizer,
-                               model,
-                               image_processor,
-                               mc_data,
-                               clip_length=clip_length,
-                               num_frames=num_frames,
-                               temperature=temperature,
-                               is_test=is_test
+        pred = llava_inference(
+            pretrained_name,
+            frames,
+            tokenizer,
+            model,
+            image_processor,
+            mc_data,
+            clip_length=clip_length,
+            num_frames=num_frames,
+            temperature=temperature,
+            is_test=is_test,
+            time_meta=time_meta
         )

         rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
-        sep = pred.index('.')
-        pred = pred[sep + 1:].strip()
+        if '.' in pred:
+            sep = pred.index('.')
+            pred = pred[sep + 1:].strip()
         preds.append(pred)

     counter = Counter(preds)
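For context, Counter(preds) resolves the ensemble by majority vote. A minimal sketch (the sample strings are made up; ties break by first-seen order, per Counter.most_common semantics):

    from collections import Counter

    preds = ['take plate', 'take plate', 'put plate']
    counter = Counter(preds)
    pred, votes = counter.most_common(1)[0]   # ('take plate', 2)

Note that with the call-site defaults changed below to temperature=0 and ensemble_k=1, the loop produces a single greedy prediction, so the vote is trivial.
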
@@ -482,14 +461,9 @@ def evaluate_on_EK100(eval_args,

     running_corrects = 0
     total_samples = 0
-
-    if not eval_args.action_predictions:
-        log_filename = f'llava_ov_{eval_args.llava_num_frames}f_{eval_args.llm_size}.log'
-    else:
-        log_filename = f'llava_ov_{eval_args.llava_num_frames}f_{eval_args.llm_size}_action_{eval_args.topk_predictions}.log'
-
+
     # Set up logging
-    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename=log_filename, filemode='w')
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filemode='w')

     console_handler = logging.StreamHandler(sys.stdout)
     console_handler.setLevel(logging.INFO)
@@ -502,23 +476,24 @@ def evaluate_on_EK100(eval_args,

     logger = logging.getLogger(__name__)

-    pretrained = f"lmms-lab/llava-onevision-qwen2-{eval_args.llm_size}-ov"
+    pretrained = f"lmms-lab/{eval_args.pretrained_name}".strip()
+    print('pretrained', pretrained)

     # so we know it's evaluation during training
     finish_early = model is not None

     if model is None:
-        if hasattr(eval_args, "llava_checkpoint"):
+        if eval_args.llava_checkpoint is not None:
             pretrained = eval_args.llava_checkpoint
-        tokenizer, model, image_processor, max_length = prepare_llava(pretrained)
+        tokenizer, model, image_processor, _ = prepare_llava(pretrained)

     if eval_args.action_predictions:
         with open(eval_args.action_predictions, 'r') as f:
             predictions = json.load(f)

     avaion_correct = 0

-    for idx, (frames, mc_data) in tqdm(enumerate(val_dataloader)):
+    for idx, (frames, mc_data, time_meta) in tqdm(enumerate(val_dataloader)):

         gt_name = mc_data['gt_answer_name'][0][0]

@@ -531,24 +506,22 @@ def evaluate_on_EK100(eval_args,
         # we don't want to evaluate the whole thing
         # let's evaluate 1000 samples to get a representative picture
         if finish_early and idx > 999:
-            break
-
-        # pred = llava_inference(frames, tokenizer, model, image_processor, mc_data, clip_length=eval_args.clip_length, num_frames=eval_args.llava_num_frames)
-
-        # # if a valid letter is found in the prediction, then we will use that as the prediction
-        # rank0_print('llava pred', pred, 'avion_pred', avion_pred, 'gt_name', gt_name)
+            break

         # Update running corrects and total samples
-        running_corrects += ensemble_llava_evaluation(gt_name,
+        running_corrects += ensemble_llava_evaluation(
+            eval_args.pretrained_name,
+            gt_name,
             frames,
             tokenizer,
             model,
             image_processor,
             mc_data,
             eval_args.clip_length,
             eval_args.llava_num_frames,
-            temperature=2.0,
-            ensemble_k=5,
+            temperature=0,
+            ensemble_k=1,
+            time_meta=time_meta,
             is_test=not finish_early)

         total_samples += 1
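
The time_meta threaded through here is presumably consumed inside llava_inference to describe the clip in the prompt. A hedged sketch of how such a preamble could be assembled from it; the actual prompt template of llava_inference is not shown in this diff, so the helper name and wording are assumptions modeled on LLaVA-Video-style prompts:

    def time_preamble(time_meta, num_frames):
        # hypothetical helper; field names match the time_meta built above
        sampled = ", ".join(time_meta['frame_time'][:num_frames])
        return (f"The video lasts for {time_meta['duration']:.2f} seconds, "
                f"and {num_frames} frames are sampled at {sampled}.")
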