Skip to content

Commit 366f59d

Browse files
Naomi-Ken-Koremmatanby
authored and committed
Add qwen-2.5-vl as default captioner.
1 parent b66830d commit 366f59d

File tree

3 files changed

+82
-14
lines changed

3 files changed

+82
-14
lines changed

README.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,28 @@ Use the directory containing your video clips (either from step 1, or your own c
104104

105105
```bash
106106
# Generate captions for all videos in the scenes directory
107+
python scripts/caption_videos.py scenes_output_dir/ \
108+
--output scenes_output_dir/captions.json
109+
```
110+
111+
By default, the script uses the Qwen2.5-VL model for captioning. If you're running into VRAM issues:
112+
113+
1. Try enabling 8-bit quantization to reduce memory usage:
114+
```bash
115+
python scripts/caption_videos.py scenes_output_dir/ \
116+
--output scenes_output_dir/captions.json \
117+
--use-8bit
118+
```
119+
120+
2. If you're still encountering memory issues, switch to the LLaVA-NeXT model, which has lower VRAM requirements:
121+
```bash
107122
python scripts/caption_videos.py scenes_output_dir/ \
108123
--output scenes_output_dir/captions.json \
109-
--captioner-type llava_next_7b
124+
--captioner-type llava_next_7b \
125+
--use-8bit
110126
```
111127

112-
This will create a captions.json file which contains video paths and their captions
128+
This will create a captions.json file which contains video paths and their captions.
113129
This JSON file will be used as input for the data preprocessing step.
114130

115131
#### 3. Dataset Preprocessing (`preprocess_dataset.py`)

scripts/caption_videos.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,10 +388,11 @@ def main( # noqa: PLR0913
388388
help="Path to output file for captions. Format determined by file extension.",
389389
),
390390
captioner_type: CaptionerType = typer.Option( # noqa: B008
391-
CaptionerType.LLAVA_NEXT_7B,
391+
CaptionerType.QWEN_25_VL,
392392
"--captioner-type",
393393
"-c",
394-
help="Type of captioner to use",
394+
help="Type of captioner to use. Valid values: 'llava_next_7b', 'qwen_25_vl'",
395+
case_sensitive=False,
395396
),
396397
device: str | None = typer.Option(
397398
None,
@@ -441,8 +442,23 @@ def main( # noqa: PLR0913
441442
) -> None:
442443
"""Auto-caption videos and images using vision-language models.
443444
444-
This tool can process individual video/image files or directories of media files and generate
445-
captions using a vision-language model. The captions can be saved in various formats.
445+
This script supports both LLaVA-NeXT and Qwen2.5-VL models for generating captions.
446+
The paths in the output file will be relative to the output file's directory.
447+
448+
Examples:
449+
    # Caption using LLaVA-NeXT
450+
caption_videos.py video.mp4 -o captions.txt
451+
452+
# Caption using Qwen2.5-VL
453+
caption_videos.py video.mp4 -o captions.txt -c qwen_25_vl
454+
455+
# Caption with custom instruction (especially useful for Qwen)
456+
caption_videos.py video.mp4 -o captions.txt -c qwen_25_vl -i "Describe this video in detail"
457+
458+
Valid captioner types:
459+
qwen_25_vl: Qwen2.5-VL-7B model (default)
460+
        llava_next_7b: LLaVA-NeXT-7B model
461+
446462
"""
447463

448464
# Determine device

src/ltxv_trainer/captioning.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
import torch
88
from diffusers import BitsAndBytesConfig
9-
from transformers import AutoModel, AutoProcessor, LlavaNextVideoForConditionalGeneration
9+
from transformers import (
10+
AutoModel,
11+
AutoProcessor,
12+
LlavaNextVideoForConditionalGeneration,
13+
Qwen2_5_VLForConditionalGeneration,
14+
)
1015
import numpy as np
1116

1217
# Should be imported after `torch` to avoid compatibility issues.
import decord
@@ -22,6 +27,7 @@ class CaptionerType(str, Enum):
2227
"""Enum for different types of video captioners."""
2328

2429
LLAVA_NEXT_7B = "llava_next_7b"
30+
QWEN_25_VL = "qwen_25_vl"
2531

2632

2733
def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel":
@@ -36,6 +42,8 @@ def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptionin
3642
"""
3743
if captioner_type == CaptionerType.LLAVA_NEXT_7B:
3844
return TransformersVlmCaptioner(model_id="llava-hf/LLaVA-NeXT-Video-7B-hf", **kwargs)
45+
elif captioner_type == CaptionerType.QWEN_25_VL:
46+
return TransformersVlmCaptioner(model_id="Qwen/Qwen2.5-VL-7B-Instruct", **kwargs)
3947
else:
4048
raise ValueError(f"Unsupported captioner type: {captioner_type}")
4149

@@ -103,6 +111,8 @@ def __init__(
103111
Args:
104112
model_id: HuggingFace model ID for LLaVA-NeXT-Video
105113
device: torch.device to use for the model
114+
use_8bit: Whether to use 8-bit quantization
115+
vlm_instruction: Instruction prompt for the model
106116
"""
107117
self.device = torch.device(device or "cuda" if torch.cuda.is_available() else "cpu")
108118
self.vlm_instruction = vlm_instruction
@@ -151,17 +161,40 @@ def caption(
151161
).to(self.device)
152162

153163
# Generate caption
154-
output_tokens = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
155-
output = self.processor.decode(output_tokens[0], skip_special_tokens=True)
156-
caption_raw = output.split("ASSISTANT: ")[1]
164+
output_tokens = self.model.generate(
165+
**inputs,
166+
max_new_tokens=200,
167+
do_sample=False,
168+
temperature=None,
169+
)
170+
171+
# Trim the generated tokens to exclude the input tokens
172+
output_tokens_trimmed = [
173+
out_ids[len(in_ids) :]
174+
for in_ids, out_ids in zip(
175+
inputs.input_ids,
176+
output_tokens,
177+
strict=False,
178+
)
179+
]
180+
181+
# Decode the generated tokens to text
182+
caption_raw = self.processor.batch_decode(
183+
output_tokens_trimmed,
184+
skip_special_tokens=True,
185+
clean_up_tokenization_spaces=False,
186+
)[0]
157187

158188
# Clean up caption
159189
caption = self._clean_raw_caption(caption_raw) if clean_caption else caption_raw
190+
160191
return caption
161192

162193
def _load_model(self, model_id: str, use_8bit: bool) -> None:
163194
if model_id == "llava-hf/LLaVA-NeXT-Video-7B-hf":
164195
model_cls = LlavaNextVideoForConditionalGeneration
196+
elif model_id == "Qwen/Qwen2.5-VL-7B-Instruct":
197+
model_cls = Qwen2_5_VLForConditionalGeneration
165198
else:
166199
model_cls = AutoModel
167200

@@ -174,7 +207,7 @@ def _load_model(self, model_id: str, use_8bit: bool) -> None:
174207
device_map=self.device.type,
175208
)
176209

177-
self.processor = AutoProcessor.from_pretrained(model_id)
210+
self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
178211

179212

180213
def example() -> None:
@@ -184,9 +217,12 @@ def example() -> None:
184217
print(f"Usage: python {sys.argv[0]} <video_path>") # noqa: T201
185218
sys.exit(1)
186219

187-
model = TransformersVlmCaptioner()
188-
caption = model.caption(sys.argv[1])
189-
print(caption) # noqa: T201
220+
# Example using both captioner types
221+
for captioner_type in [CaptionerType.LLAVA_NEXT_7B, CaptionerType.QWEN_25_VL]:
222+
print(f"\nUsing {captioner_type} captioner:") # noqa: T201
223+
model = create_captioner(captioner_type)
224+
caption = model.caption(sys.argv[1])
225+
print(f"CAPTION: {caption}") # noqa: T201
190226

191227

192228
if __name__ == "__main__":

0 commit comments

Comments
 (0)