@@ -610,6 +610,21 @@ def get_args_parser():
610610
611611 return parser
612612
def prepare_llava(pretrained=None, model_name="llava_qwen", device_map="auto"):
    """Load a LLaVA checkpoint once so it can be reused across inference calls.

    Args:
        pretrained: Checkpoint identifier forwarded to ``load_pretrained_model``,
            e.g. ``f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov"``. When
            omitted, the module-level ``pretrained`` variable is used, which is
            how the script entry point has historically supplied it.
        model_name: Model name forwarded to ``load_pretrained_model``.
        device_map: Device placement forwarded to ``load_pretrained_model``.

    Returns:
        The tuple ``(tokenizer, model, image_processor, max_length)`` produced
        by ``load_pretrained_model``.

    Raises:
        NameError: If ``pretrained`` is not passed and no module-level
            ``pretrained`` has been defined.
    """
    import warnings

    if pretrained is None:
        # Backward compatibility: the zero-argument form relied on a global
        # assigned by the ``__main__`` block. Resolve it explicitly and fail
        # early with a clear message instead of deep inside the loader call.
        try:
            pretrained = globals()["pretrained"]
        except KeyError:
            raise NameError(
                "prepare_llava() needs a checkpoint: pass `pretrained=...` "
                "or define a module-level `pretrained` before calling it"
            ) from None

    # Imported lazily so the module can be imported without llava installed.
    from llava.model.builder import load_pretrained_model

    # NOTE: this silences *all* warnings for the rest of the process.
    warnings.filterwarnings("ignore")

    # Load the OneVision model
    tokenizer, model, image_processor, max_length = load_pretrained_model(
        pretrained,
        None,
        model_name,
        device_map=device_map,
        attn_implementation="sdpa",
    )

    return tokenizer, model, image_processor, max_length
626+
627+
613628if __name__ == '__main__' :
614629 from moviepy .editor import ImageSequenceClip
615630 import torchvision
@@ -640,10 +655,24 @@ def get_args_parser():
640655
641656 # Set up logging
642657 logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' , filename = f'llava_ov_{ args .llava_num_frames } f_{ args .llm_size } .log' , filemode = 'w' )
658+
659+ console_handler = logging .StreamHandler (sys .stdout )
660+ console_handler .setLevel (logging .INFO )
661+
662+ # Set the same format for console handler as well
663+ console_handler .setFormatter (logging .Formatter ('%(asctime)s - %(levelname)s - %(message)s' ))
664+
665+ # Add the console handler to the root logger
666+ logging .getLogger ().addHandler (console_handler )
667+
643668 logger = logging .getLogger (__name__ )
644669
670+ pretrained = f"lmms-lab/llava-onevision-qwen2-{ args .llm_size } -ov"
671+
672+ tokenizer , model , image_processor , max_length = prepare_llava ()
673+
645674 for idx , (frames , gt ) in tqdm (enumerate (val_dataloader )):
646- pred = llava_inference (frames , gt , logger , num_frames = args . llava_num_frames , llm_size = args .llm_size )
675+ pred = llava_inference (frames , tokenizer , model , image_processor , max_length , gt , num_frames = args .llava_num_frames )
647676
648677 # if valid letter is found in the prediction, then we will use that as the prediction
649678 found = False
@@ -671,4 +700,4 @@ def get_args_parser():
671700 # get final accuracy
672701 accuracy = np .mean (gts == preds )
673702 logger .info (f'Final accuracy: { accuracy :.4f} ' )
674-
703+
0 commit comments