
Commit 789531b

Merge pull request #152 from LLaVA-VL/fix/onevision_tut
Provide the correct video processing logic with decord
2 parents 3fbf54b + 742235b commit 789531b

docs/LLaVA_OneVision_Tutorials.ipynb

Lines changed: 18 additions & 19 deletions
@@ -237,6 +237,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"from operator import attrgetter\n",
 "from llava.model.builder import load_pretrained_model\n",
 "from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token\n",
 "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX\n",
@@ -249,41 +250,39 @@
 "import requests\n",
 "import copy\n",
 "import warnings\n",
+"from decord import VideoReader, cpu\n",
 "\n",
 "warnings.filterwarnings(\"ignore\")\n",
 "# Load the OneVision model\n",
 "pretrained = \"lmms-lab/llava-onevision-qwen2-0.5b-ov\"\n",
 "model_name = \"llava_qwen\"\n",
 "device = \"cuda\"\n",
 "device_map = \"auto\"\n",
-"tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)\n",
+"tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation=\"sdpa\")\n",
 "\n",
 "model.eval()\n",
 "\n",
 "\n",
 "# Function to extract frames from video\n",
-"def extract_frames(video_path, num_frames=8):\n",
-"    cap = cv2.VideoCapture(video_path)\n",
-"    frames = []\n",
-"    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
-"    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)\n",
-"\n",
-"    for i in indices:\n",
-"        cap.set(cv2.CAP_PROP_POS_FRAMES, i)\n",
-"        ret, frame = cap.read()\n",
-"        if ret:\n",
-"            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
-"            frames.append(Image.fromarray(frame))\n",
-"\n",
-"    cap.release()\n",
-"    return frames\n",
+"def load_video(video_path, max_frames_num):\n",
+"    if type(video_path) == str:\n",
+"        vr = VideoReader(video_path, ctx=cpu(0))\n",
+"    else:\n",
+"        vr = VideoReader(video_path[0], ctx=cpu(0))\n",
+"    total_frame_num = len(vr)\n",
+"    uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)\n",
+"    frame_idx = uniform_sampled_frames.tolist()\n",
+"    spare_frames = vr.get_batch(frame_idx).asnumpy()\n",
+"    return spare_frames  # (frames, height, width, channels)\n",
 "\n",
 "\n",
 "# Load and process video\n",
 "video_path = \"jobs.mp4\"\n",
-"video_frames = extract_frames(video_path)\n",
-"image_tensors = process_images(video_frames, image_processor, model.config)\n",
-"image_tensors = [_image.to(dtype=torch.float16, device=device) for _image in image_tensors]\n",
+"video_frames = load_video(video_path, 16)\n",
+"print(video_frames.shape)  # (16, 1024, 576, 3)\n",
+"image_tensors = []\n",
+"frames = image_processor.preprocess(video_frames, return_tensors=\"pt\")[\"pixel_values\"].half().cuda()\n",
+"image_tensors.append(frames)\n",
 "\n",
 "# Prepare conversation input\n",
 "conv_template = \"qwen_1_5\"\n",
