|
13 | 13 | import warnings |
14 | 14 | from decord import VideoReader, cpu |
15 | 15 |
|
16 | | -warnings.filterwarnings("ignore") |
17 | | -# Load the OneVision model |
18 | | -pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov" |
19 | | -model_name = "llava_qwen" |
20 | | -device = "cuda" |
21 | | -device_map = "auto" |
22 | | -tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa") |
| 16 | +def llava_inference(video_frames, gt, logger, num_frames=16, llm_size='7b'): |
| 17 | + |
| 18 | + warnings.filterwarnings("ignore") |
| 19 | + # Load the OneVision model |
| 20 | + pretrained = f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov" |
| 21 | + logger.info(f"Loading model {pretrained}") |
| 22 | + model_name = "llava_qwen" |
| 23 | + device = "cuda" |
| 24 | + device_map = "auto" |
| 25 | + tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa") |
| 26 | + |
| 27 | + model.eval() |
| 28 | + video_frames = video_frames[0] |
23 | 29 |
|
24 | | -model.eval() |
| 30 | + temporal_stride = 16 // num_frames |
25 | 31 |
|
26 | | -def llava_inference(video_frames, gt): |
27 | | - video_frames = video_frames[0] |
| 32 | + video_frames = video_frames[::temporal_stride] |
28 | 33 | image_tensors = [] |
29 | 34 | frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda() |
30 | 35 | image_tensors.append(frames) |
|
0 commit comments