Skip to content

Commit cbac23f

Browse files
committed
updated llava inference
1 parent d6191f5 commit cbac23f

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

action/llava_ov_inference.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,18 +13,23 @@
1313
import warnings
1414
from decord import VideoReader, cpu
1515

16-
warnings.filterwarnings("ignore")
17-
# Load the OneVision model
18-
pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov"
19-
model_name = "llava_qwen"
20-
device = "cuda"
21-
device_map = "auto"
22-
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
16+
def llava_inference(video_frames, gt, logger, num_frames=16, llm_size='7b'):
17+
18+
warnings.filterwarnings("ignore")
19+
# Load the OneVision model
20+
pretrained = f"lmms-lab/llava-onevision-qwen2-{llm_size}-ov"
21+
logger.info(f"Loading model {pretrained}")
22+
model_name = "llava_qwen"
23+
device = "cuda"
24+
device_map = "auto"
25+
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
26+
27+
model.eval()
28+
video_frames = video_frames[0]
2329

24-
model.eval()
30+
temporal_stride = 16 // num_frames
2531

26-
def llava_inference(video_frames, gt):
27-
video_frames = video_frames[0]
32+
video_frames = video_frames[::temporal_stride]
2833
image_tensors = []
2934
frames = image_processor.preprocess(video_frames, return_tensors="pt")["pixel_values"].half().cuda()
3035
image_tensors.append(frames)

0 commit comments

Comments (0)