import torch
from diffusers import BitsAndBytesConfig
from transformers import (
    AutoModel,
    AutoProcessor,
    LlavaNextVideoForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

import numpy as np

# Should be imported after `torch` to avoid compatibility issues.
import decord
class CaptionerType(str, Enum):
    """Enum for different types of video captioners.

    Members compare equal to their string value (str mixin), so they can be
    used directly wherever a plain model-key string is expected.
    """

    LLAVA_NEXT_7B = "llava_next_7b"
    QWEN_25_VL = "qwen_25_vl"
2632
def create_captioner(captioner_type: CaptionerType, **kwargs) -> "MediaCaptioningModel":
    """Build the captioning model that corresponds to *captioner_type*.

    Args:
        captioner_type: Which captioner implementation to instantiate.
        **kwargs: Extra keyword arguments forwarded to the captioner
            constructor (e.g. device, quantization options).

    Returns:
        A ready-to-use captioning model instance.

    Raises:
        ValueError: If `captioner_type` is not a supported captioner.
    """
    # Dispatch table: captioner type -> HuggingFace model ID.
    model_ids = {
        CaptionerType.LLAVA_NEXT_7B: "llava-hf/LLaVA-NeXT-Video-7B-hf",
        CaptionerType.QWEN_25_VL: "Qwen/Qwen2.5-VL-7B-Instruct",
    }
    model_id = model_ids.get(captioner_type)
    if model_id is None:
        raise ValueError(f"Unsupported captioner type: {captioner_type}")
    return TransformersVlmCaptioner(model_id=model_id, **kwargs)
4149
@@ -103,6 +111,8 @@ def __init__(
103111 Args:
104112 model_id: HuggingFace model ID for LLaVA-NeXT-Video
105113 device: torch.device to use for the model
114+ use_8bit: Whether to use 8-bit quantization
115+ vlm_instruction: Instruction prompt for the model
106116 """
107117 self .device = torch .device (device or "cuda" if torch .cuda .is_available () else "cpu" )
108118 self .vlm_instruction = vlm_instruction
@@ -151,17 +161,40 @@ def caption(
151161 ).to (self .device )
152162
153163 # Generate caption
154- output_tokens = self .model .generate (** inputs , max_new_tokens = 200 , do_sample = False )
155- output = self .processor .decode (output_tokens [0 ], skip_special_tokens = True )
156- caption_raw = output .split ("ASSISTANT: " )[1 ]
164+ output_tokens = self .model .generate (
165+ ** inputs ,
166+ max_new_tokens = 200 ,
167+ do_sample = False ,
168+ temperature = None ,
169+ )
170+
171+ # Trim the generated tokens to exclude the input tokens
172+ output_tokens_trimmed = [
173+ out_ids [len (in_ids ) :]
174+ for in_ids , out_ids in zip (
175+ inputs .input_ids ,
176+ output_tokens ,
177+ strict = False ,
178+ )
179+ ]
180+
181+ # Decode the generated tokens to text
182+ caption_raw = self .processor .batch_decode (
183+ output_tokens_trimmed ,
184+ skip_special_tokens = True ,
185+ clean_up_tokenization_spaces = False ,
186+ )[0 ]
157187
158188 # Clean up caption
159189 caption = self ._clean_raw_caption (caption_raw ) if clean_caption else caption_raw
190+
160191 return caption
161192
162193 def _load_model (self , model_id : str , use_8bit : bool ) -> None :
163194 if model_id == "llava-hf/LLaVA-NeXT-Video-7B-hf" :
164195 model_cls = LlavaNextVideoForConditionalGeneration
196+ elif model_id == "Qwen/Qwen2.5-VL-7B-Instruct" :
197+ model_cls = Qwen2_5_VLForConditionalGeneration
165198 else :
166199 model_cls = AutoModel
167200
@@ -174,7 +207,7 @@ def _load_model(self, model_id: str, use_8bit: bool) -> None:
174207 device_map = self .device .type ,
175208 )
176209
177- self .processor = AutoProcessor .from_pretrained (model_id )
210+ self .processor = AutoProcessor .from_pretrained (model_id , use_fast = True )
178211
179212
180213def example () -> None :
@@ -184,9 +217,12 @@ def example() -> None:
184217 print (f"Usage: python { sys .argv [0 ]} <video_path>" ) # noqa: T201
185218 sys .exit (1 )
186219
187- model = TransformersVlmCaptioner ()
188- caption = model .caption (sys .argv [1 ])
189- print (caption ) # noqa: T201
220+ # Example using both captioner types
221+ for captioner_type in [CaptionerType .LLAVA_NEXT_7B , CaptionerType .QWEN_25_VL ]:
222+ print (f"\n Using { captioner_type } captioner:" ) # noqa: T201
223+ model = create_captioner (captioner_type )
224+ caption = model .caption (sys .argv [1 ])
225+ print (f"CAPTION: { caption } " ) # noqa: T201
190226
191227
192228if __name__ == "__main__" :
0 commit comments