diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index ec736aa236ff..f958a150a84b 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -562,7 +562,10 @@ Specified using `--task generate`. | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` +| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + IE+ + VE+ | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + IE+ + VE+ | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | | ✅︎ | ✅︎\* | +| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 825abeaf7e75..4bc622151d11 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1076,6 +1076,116 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# Qwen3-VL-Dense +def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-4B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Qwen3-VL-MOE +def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +def run_qwen3_omni_moe(questions: list[str], modality: str) -> ModelRequestData: + model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + limit_mm_per_prompt={modality: 1}, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompts = [ + ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1146,11 +1256,20 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "qwen3_vl": run_qwen3_vl, + "qwen3_vl_moe": run_qwen3_vl_moe, + "qwen3_omni_moe": run_qwen3_omni_moe, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "tarsier": run_tarsier, } +MODELS_NEED_VIDEO_METADATA = [ + "glm4_1v", + "qwen3_vl", + "qwen3_vl_moe", +] + def get_multi_modal_input(args): """ @@ -1176,12 +1295,13 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question + needs_metadata = args.model_type in MODELS_NEED_VIDEO_METADATA video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata vid_questions = ["Why is this video funny?"] return { - "data": [(video, metadata)] if args.model_type == "glm4_1v" else video, + "data": [(video, metadata)] if needs_metadata else video, "questions": vid_questions, } diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index e52a42a2e452..84f2b7cd368d 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -30,6 +30,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: """ # Ensure video metadata is included if "video" in mm_data: + # GLM4.1V doesn't support multiple videos video = mm_data["video"] mm_data["video"] = (video, { "total_num_frames": len(video), @@ -39,6 +40,33 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: }) return mm_data +def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict: + """ + Patch the multimodal data for Qwen3-VL model. + """ + + def create_metadata(frames: np.ndarray): + num_frames = len(frames) + return { + "total_num_frames": num_frames, + "fps": 2.0, + "duration": num_frames / 2.0, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + "do_sample_frames": True, + } + + # Ensure video metadata is included + if "video" in mm_data: + video = mm_data["video"] + if isinstance(video, list): + # multiple videos + mm_data["video"] = [(vid, create_metadata(vid)) for vid in video] + else: + # single video + mm_data["video"] = (video, create_metadata(video)) + return mm_data + def _test_processing_correctness( model_id: str, @@ -171,8 +199,10 @@ def _test_processing_correctness( } MM_DATA_PATCHES = { - # GLM4.1V requires video metadata to be included in the input + # GLM4.1V and Qwen3-VL requires video metadata to be included in the input "glm4v": glm4_1v_patch_mm_data, + "qwen3_vl": qwen3_vl_patch_mm_data, + "qwen3_vl_moe": qwen3_vl_patch_mm_data, } @@ -304,6 +334,8 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "Qwen/Qwen3-VL-4B-Instruct", + "Qwen/Qwen3-VL-30B-A3B-Instruct", "Skywork/Skywork-R1V-38B", "fixie-ai/ultravox-v0_5-llama-3_2-1b", "openai/whisper-large-v3", diff --git a/tests/models/registry.py b/tests/models/registry.py index c90aed31f597..ac7545bd42c9 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -404,6 +404,15 @@ def check_available_online( "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501 + max_model_len=4096, + min_transformers_version="4.57"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501 + max_model_len=4096, + min_transformers_version="4.57"), + "Qwen3OmniMoeThinkerForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501 + max_model_len=4096, + min_transformers_version="4.57"), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501 diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 16412121cf0a..93a4733d46b5 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -77,7 +77,7 @@ def video_to_pil_images_list(path: str, ] -def video_get_metadata(path: str) -> dict[str, Any]: +def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]: cap = cv2.VideoCapture(path) if not cap.isOpened(): raise ValueError(f"Could not open video file {path}") @@ -86,11 +86,18 @@ def video_get_metadata(path: str) -> dict[str, Any]: fps = cap.get(cv2.CAP_PROP_FPS) duration = total_frames / fps if fps > 0 else 0 + if num_frames == -1 or num_frames > total_frames: + num_frames = total_frames + metadata = { "total_num_frames": total_frames, "fps": fps, "duration": duration, - "video_backend": "opencv" + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames, } return metadata @@ -126,7 +133,7 @@ def np_ndarrays(self) -> npt.NDArray: @property def metadata(self) -> dict[str, Any]: video_path = download_video_asset(self.filename) - ret = video_get_metadata(video_path) + ret = video_get_metadata(video_path, self.num_frames) return ret def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray: diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0ce8ab650d06..6050a5a368a5 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -24,6 +24,14 @@ from vllm.v1.attention.backends.utils import validate_kv_sharing_target +def check_upstream_fa_availability(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( + ) and current_platform.has_device_capability(80): + from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() + return False + + class Attention(nn.Module): """Attention layer. diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 73657666dfdd..783673aa4d96 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -537,6 +537,8 @@ def _placeholder_str(self, modality: ModalityStr, return "<|vision_start|><|image_pad|><|vision_end|>" if model_type == "qwen2_5_omni": return "<|vision_start|><|IMAGE|><|vision_end|>" + if model_type in ("qwen3_omni_moe", "qwen3_vl_moe", "qwen3_vl"): + return "<|vision_start|><|image_pad|><|vision_end|>" if model_type == "molmo": return "" if model_type == "aria": @@ -555,6 +557,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type in ("qwen2_audio", "qwen2_5_omni"): return (f"Audio {current_count}: " f"<|audio_bos|><|AUDIO|><|audio_eos|>") + if model_type == "qwen3_omni_moe": + return "<|audio_start|><|audio_pad|><|audio_end|>" if model_type == "minicpmo": return "()" raise TypeError(f"Unknown model type: {model_type}") @@ -567,6 +571,8 @@ def _placeholder_str(self, modality: ModalityStr, return "<|vision_start|><|video_pad|><|vision_end|>" if model_type == "qwen2_5_omni": return "<|vision_start|><|VIDEO|><|vision_end|>" + if model_type in ("qwen3_omni_moe", "qwen3_vl_moe", "qwen3_vl"): + return "<|vision_start|><|video_pad|><|vision_end|>" if model_type in ("minicpmo", "minicpmv"): return "()" if model_type.startswith("llava"): diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index ab237a460dd2..2500eab09ce7 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -1168,6 +1168,18 @@ def forward_hpu( # type: ignore[override] return query_out.type_as(query), key_out.type_as(key) +def apply_interleaved_rope(x: torch.Tensor, + mrope_section: list[int]) -> torch.Tensor: + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + """ + x_t = x[0].clone() + x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3] + x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3] + return x_t + + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -1180,6 +1192,7 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, + mrope_interleaved: Optional[bool] = False, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get @@ -1189,6 +1202,7 @@ def __init__( base, is_neox_style, dtype) self.mrope_section = mrope_section + self.mrope_interleaved = mrope_interleaved if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 @@ -1215,17 +1229,20 @@ def forward( cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - cos = torch.cat([ - m[i] - for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) - ], - dim=-1) - sin = torch.cat([ - m[i] - for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) - ], - dim=-1) + if self.mrope_interleaved: + cos = apply_interleaved_rope(cos, self.mrope_section) + sin = apply_interleaved_rope(sin, self.mrope_section) + else: + cos = torch.cat([ + m[i] for i, m in enumerate( + cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] for i, m in enumerate( + sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -1292,7 +1309,7 @@ def get_input_positions_tensor( ) -> tuple[torch.Tensor, int]: from vllm.transformers_utils.config import thinker_uses_mrope if thinker_uses_mrope(hf_config) and \ - hf_config.model_type == "qwen2_5_omni": + hf_config.model_type in ["qwen2_5_omni","qwen3_omni_moe"]: return cls._omni_get_input_positions_tensor( input_tokens=input_tokens, hf_config=hf_config, @@ -1313,6 +1330,15 @@ def get_input_positions_tensor( context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + return cls._qwen3vl_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -1532,6 +1558,386 @@ def _vl_get_input_positions_tensor( return llm_positions, mrope_position_delta + @classmethod + def _qwen3vl_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value.""" + + video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw + for _ in range(t)] + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions, mrope_position_delta + + @classmethod + def _omni3_get_input_positions_tensor( + cls, + config, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + audio_seqlens: Optional[torch.LongTensor] = None, + second_per_grids: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + + def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor): + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - + 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + spatial_merge_size = config.vision_config.spatial_merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + audio_token_id = config.audio_token_id + vision_start_token_id = config.vision_start_token_id + audio_start_token_id = config.audio_start_token_id + position_id_per_seconds = config.position_id_per_seconds + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None + or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.zeros( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ((vision_tokens == audio_start_token_id).sum() + if use_audio_in_video else + (vision_tokens == video_token_id).sum()) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = (image_nums, + video_nums, + audio_nums) + multimodal_nums = (image_nums + audio_nums + if use_audio_in_video else image_nums + + video_nums + audio_nums) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if (image_token_id in input_tokens or video_token_id + in input_tokens) and (remain_videos > 0 + or remain_images > 0): + ed_vision_start = input_tokens.index( + vision_start_token_id, st) + else: + ed_vision_start = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio_start = input_tokens.index( + audio_start_token_id, st) + else: + ed_audio_start = len(input_tokens) + 1 + min_ed = min(ed_vision_start, ed_audio_start) + if min_ed == ed_audio_start: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].long().max( + ) + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand( + 3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx]) + llm_pos_ids = torch.arange(audio_len).view( + 1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + + st_idx) + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + elif min_ed == ed_vision_start and input_ids[ + ed_vision_start + 1] == image_token_id: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].long().max( + ) + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand( + 3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = ((torch.arange(grid_t)) * 1 * + position_id_per_seconds) + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, + grid_hs, grid_ws) + image_len = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + + st_idx) + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + elif min_ed == ed_vision_start and input_ids[ + ed_vision_start + + 1] == video_token_id and not use_audio_in_video: + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].long().max( + ) + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand( + 3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ((torch.arange(grid_t)) * + second_per_grids[video_idx].cpu() * + position_id_per_seconds) + llm_pos_ids = cls._get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, + grid_hs, grid_ws) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + + st_idx) + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + elif (min_ed == ed_vision_start + and ed_vision_start + 1 == ed_audio_start + and use_audio_in_video): + text_len = min_ed - st + if text_len != 0: + st_idx = llm_pos_ids_list[-1].long().max( + ) + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand( + 3, -1) + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + + st_idx) + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + + st_idx) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + audio_len = _get_feat_extract_output_lengths( + audio_seqlens[audio_idx]) + audio_llm_pos_ids = torch.arange(audio_len).view( + 1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ((torch.arange(grid_t)) * + second_per_grids[video_idx].cpu() * + position_id_per_seconds) + video_llm_pos_ids = cls._get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, + grid_hs, grid_ws) + video_data_index, audio_data_index = 0, 0 + while (video_data_index < video_llm_pos_ids.shape[-1] + and audio_data_index + < audio_llm_pos_ids.shape[-1]): + if video_llm_pos_ids[0][ + video_data_index] <= audio_llm_pos_ids[0][ + audio_data_index]: + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_data_index: + video_data_index + 1]) + video_data_index += 1 + else: + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_data_index: + audio_data_index + 1]) + audio_data_index += 1 + if video_data_index < video_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_data_index: + video_llm_pos_ids.shape[-1]]) + if audio_data_index < audio_llm_pos_ids.shape[-1]: + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_data_index: + audio_llm_pos_ids.shape[-1]]) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2) + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + + st_idx) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + + st_idx) + st += (text_len + bos_len * 2 + audio_len + video_len + + eos_len * 2) + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].long().max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + llm_positions = torch.cat(llm_pos_ids_list, + dim=1).reshape(3, -1) + + position_ids[..., i, + attention_mask[i] == 1] = llm_positions.to( + position_ids.device) + mrope_position_deltas.append(llm_positions.long().max() + 1 - + len(input_ids)) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas.long() + else: + position_ids = attention_mask.cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to( + attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum( + attention_mask, dim=-1, keepdim=True) + return position_ids, mrope_position_deltas.long() + @classmethod def _omni_get_input_positions_tensor( cls, @@ -1563,8 +1969,25 @@ def _omni_get_input_positions_tensor( # TODO(fyabc): refactor and share more code with # _vl_get_input_positions_tensor. - + model_type = hf_config.model_type thinker_config = hf_config.thinker_config + + if isinstance(image_grid_thw, list): + image_grid_thw = torch.tensor(image_grid_thw) + if isinstance(video_grid_thw, list): + video_grid_thw = torch.tensor(video_grid_thw) + + if "qwen3_omni" in model_type: + (llm_positions, + mrope_position_delta) = cls._omni3_get_input_positions_tensor( + thinker_config, torch.tensor([input_tokens]), image_grid_thw, + video_grid_thw, None, use_audio_in_video, + audio_feature_lengths, + torch.tensor([1] * torch.tensor(video_grid_thw).shape[0])) + llm_positions = llm_positions.squeeze(1) + mrope_position_delta = mrope_position_delta.squeeze() + return llm_positions, mrope_position_delta + audio_token_id = thinker_config.audio_token_index image_token_id = thinker_config.image_token_index video_token_id = thinker_config.video_token_index @@ -1577,11 +2000,6 @@ def _omni_get_input_positions_tensor( tokens_per_second = getattr(thinker_config.vision_config, "tokens_per_second", 25) - if isinstance(image_grid_thw, list): - image_grid_thw = torch.tensor(image_grid_thw) - if isinstance(video_grid_thw, list): - video_grid_thw = torch.tensor(video_grid_thw) - src_item = input_tokens audio_seqlens = audio_feature_lengths if not second_per_grid_ts: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8e6c76b5a564..b54f187f8e7f 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -299,7 +299,7 @@ def __init__(self, decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index d8d324d7cd16..af2de72e68ab 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -87,7 +87,7 @@ from habana_frameworks.torch.hpex.kernels import FusedSDPA # For profile run -_MAX_FRAMES_PER_VIDEO = 16 +_MAX_FRAMES_PER_VIDEO = 600 # === Vision Inputs === # diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ca509df0468c..881c9d1b418c 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -321,7 +321,7 @@ class Qwen3MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() - config = vllm_config.model_config.hf_config + config = vllm_config.model_config.hf_config.get_text_config() cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py new file mode 100644 index 000000000000..1ae38c81b4cb --- /dev/null +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -0,0 +1,1545 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3-Omni-Moe model (thinker part).""" + +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( + Qwen3OmniMoeConfig, Qwen3OmniMoeThinkerConfig) +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeAudioEncoder) +from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( + Qwen3OmniMoeProcessor) +from transformers.models.whisper import WhisperFeatureExtractor + +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2_audio import (Qwen2AudioInputs, + Qwen2AudioProcessingInfo) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalKwargsItems, NestedTensors +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PlaceholderFeaturesInfo, + PromptReplacement, PromptUpdate) +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import decode_tokens + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .qwen2_5_omni_thinker import (Qwen2_5OmniConditionalGenerationMixin, + Qwen2_5OmniThinkerDummyInputsBuilder, + Qwen2_5OmniThinkerMultiModalProcessor, + Qwen2_5OmniThinkerProcessingInfo) +from .qwen2_5_vl import (Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLImageInputs, + Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoInputs) +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +try: + import flash_attn +except (ImportError, ModuleNotFoundError): + flash_attn = None + +logger = init_logger(__name__) +is_hpu = current_platform.is_hpu() + +if is_hpu: + import habana_frameworks.torch as htorch + import habana_frameworks.torch.core as htcore + + +class Qwen3_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen3_VisionMLP(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.linear_fc1 = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1") + self.linear_fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2") + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen3_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + attn_mask=attn_mask) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.ln_q = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.ln_q(x.view(-1, self.hidden_size)) + else: + x = self.ln_q(x).view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen3Omni_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.image_size = vision_config.image_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + # vit pos embedding, TODO: spatial_patch_size vs patch_size + if self.apply_vit_abs_pos_embed: + self.pos_embed = nn.Embedding( + (self.image_size // self.patch_size)**2, self.hidden_size) + else: + self.pos_embed = nn.Parameter( + torch.empty([ + 1, (self.image_size // self.patch_size)**2, + self.hidden_size + ])) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + if self.deepstack_visual_indexes is not None: + self.merger_list = nn.ModuleList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.merger_list.{layer_idx}") + for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + self.attn_backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + num_grid_per_side = self.image_size // self.patch_size + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + # TODO: use torch instand of np + for t, h, w in grid_thw: + h_idxs = np.linspace(0, num_grid_per_side - 1, h) + w_idxs = np.linspace(0, num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.astype(int) + w_idxs_floor = w_idxs.astype(int) + h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - + 1) + w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - + 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + + weight_list[0].extend( + ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[1].extend( + ((1 - dh)[None].T * dw[None]).flatten().tolist() * t) + weight_list[2].extend( + (dh[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[3].extend( + (dh[None].T * dw[None]).flatten().tolist() * t) + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + p0 = self.pos_embed( + torch.tensor( + idx_list[0], dtype=torch.long, device=device)) * torch.tensor( + weight_list[0], dtype=dtype, device=device)[:, None] + p1 = self.pos_embed( + torch.tensor( + idx_list[1], dtype=torch.long, device=device)) * torch.tensor( + weight_list[1], dtype=dtype, device=device)[:, None] + p2 = self.pos_embed( + torch.tensor( + idx_list[2], dtype=torch.long, device=device)) * torch.tensor( + weight_list[2], dtype=dtype, device=device)[:, None] + p3 = self.pos_embed( + torch.tensor( + idx_list[3], dtype=torch.long, device=device)) * torch.tensor( + weight_list[3], dtype=dtype, device=device)[:, None] + + patch_pos_embeds = p0 + p1 + p2 + p3 + patch_pos_embeds = patch_pos_embeds.split( + [t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] + m_size = self.spatial_merge_size + for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): + pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size, + m_size, -1).permute(0, 1, 3, 2, 4, + 5).flatten(0, 4) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + if self.apply_vit_abs_pos_embed: + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + hidden_states_list = [] + deepstack_visual_indexes = self.deepstack_visual_indexes + + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + if (deepstack_visual_indexes is not None + and layer_num in deepstack_visual_indexes): + hidden_states_list.append(hidden_states) + + hidden_states = self.merger(hidden_states) + + # processing deepstack + if deepstack_visual_indexes is not None: + processed_hidden_states_list = [hidden_states] + for idx, x in enumerate(hidden_states_list): + x = self.merger_list[idx](x) + processed_hidden_states_list.append(x) + # we cat the original visual features and deepstack + # features along the feature dim + hidden_states = torch.cat( + processed_hidden_states_list, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3Omni_VisionTransformerStaticShape(Qwen3Omni_VisionTransformer): + """ + Here we overwrite some of the methods of Qwen3Omni_VisionTransformer + to make the model more friendly to static shapes. Specifically, + we split the forward method into: + - pre_attn (dynamic) + - forward (static shape) + and we should call get_image_embeds instead of forward, allowing + the forward method ro run with HPU_Graphs, whereas the + pre_attn and post_attn methods are allow to be dynamic. + """ + + def pad_multimodal_data(self, + pixel_values, + vision_buckets, + constant_value=0): + + desired_number_of_pixels = vision_buckets.get_multimodal_bucket( + pixel_values.shape[0]) + padding_len = desired_number_of_pixels - pixel_values.shape[0] + if padding_len <= 0: + return pixel_values + + logger_msg = "Padding current number pixel " \ + + str(pixel_values.shape[0]) \ + + " to " \ + + str(desired_number_of_pixels) + logger.debug(logger_msg) + + pixel_values = F.pad(pixel_values, (0, 0, 0, padding_len), "constant", + constant_value) + + return pixel_values + + def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor, + vision_buckets): + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + attention_mask = torch.ones(hidden_states.size(0), + 1).to(device=self.device) + + hidden_states = self.pad_multimodal_data(hidden_states, vision_buckets, + 0) + rotary_pos_emb = self.pad_multimodal_data(rotary_pos_emb, + vision_buckets, -100) + attention_mask = self.pad_multimodal_data(attention_mask, + vision_buckets, 0) + + return hidden_states, rotary_pos_emb, attention_mask + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor], + rotary_pos_emb: torch.Tensor, + ) -> torch.Tensor: + hidden_states = x.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + hidden_states_list = [] + deepstack_visual_indexes = self.deepstack_visual_indexes + + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + attn_mask=attn_mask, + cu_seqlens=None, + ) + if (deepstack_visual_indexes is not None + and layer_num in deepstack_visual_indexes): + hidden_states_list.append(hidden_states) + + hidden_states = self.merger(hidden_states) + + # processing deepstack + if deepstack_visual_indexes is not None: + processed_hidden_states_list = [hidden_states] + for idx, x in enumerate(hidden_states_list): + x = self.merger_list[idx](x) + processed_hidden_states_list.append(x) + # we cat the original visual features + # and deepstack features along the feature dim + hidden_states = torch.cat( + processed_hidden_states_list, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + + return hidden_states + + def get_image_embeds( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + vision_buckets, + ) -> torch.Tensor: + + offset = 0 + results = [] + # process each image one by one + for img_idx in range(grid_thw.shape[0]): + img_shape = grid_thw[img_idx, :].unsqueeze(0).clone() + # For video, we process frames separately + grid_t = grid_thw[img_idx, 0] + img_shape[0, 0] = 1 + curr_img_size = img_shape.prod() + for _ in torch.arange(0, grid_t): + pixel_values_curr_img = pixel_values[offset:offset + + curr_img_size, :] + + offset += curr_img_size + + (pixel_values_curr_img_padded, rot_pos_emb, + attention_mask) = self.pre_attn(pixel_values_curr_img, + img_shape, vision_buckets) + + fullatt_block_attn_mask = \ + attention_mask.squeeze(1).unsqueeze(0) * attention_mask + + extra_forward_kwargs = {} + if htorch.utils.internal.is_lazy(): + padded_len = pixel_values_curr_img_padded.shape[0] + use_graph = vision_buckets.use_graph(padded_len) + extra_forward_kwargs.update( + {"bypass_hpu_graphs": not use_graph}) + + htcore.mark_step() + hidden_states = self.forward(pixel_values_curr_img_padded, + rotary_pos_emb=rot_pos_emb, + attn_mask=fullatt_block_attn_mask, + **extra_forward_kwargs) + htcore.mark_step() + + post_embed_size = curr_img_size // self.spatial_merge_unit + results += [hidden_states[:post_embed_size, :]] + + results_cat = torch.concat(results) + image_embeds = results_cat + return image_embeds + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + "deepstack_input_embeds": 0 + }) +class Qwen3MoeLLMModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.deepstack_multiscale_layer_start = 1 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for layer_idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and \ + layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[ + f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Qwen3MoeLLMModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +class Qwen3OmniMoeThinkerProcessingInfo(Qwen2AudioProcessingInfo, + Qwen2_5_VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config + + def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor: + processor = self.ctx.get_hf_processor( + Qwen3OmniMoeProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + if not hasattr(processor, "audio_token"): + processor.audio_token = "<|audio_pad|>" + if not hasattr(processor, "image_token"): + processor.image_token = "<|image_pad|>" + if not hasattr(processor, "video_token"): + processor.video_token = "<|video_pad|>" + return processor + + def get_feature_extractor(self, **kwargs: object): + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None, "image": None, "video": None} + + +class Qwen3OmniMoeThinkerMultiModalProcessor( + Qwen2_5OmniThinkerMultiModalProcessor, ): + + def _get_feat_extract_output_lengths( + self, input_lengths: torch.Tensor) -> torch.Tensor: + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - + 1) // 2 + 1 + (input_lengths // 100) * 13 + return feat_lengths, output_lengths + + def _maybe_apply_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + prompt_ids: list[int], + mm_kwargs: MultiModalKwargsItems, + is_update_applied: bool, + ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: + """ + Qwen3-Omni reimplements this function to handle `use_audio_in_video`. + """ + unbound_prompt_updates = self._get_prompt_updates( + mm_items, + hf_processor_mm_kwargs, + mm_kwargs, + ) + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates) + mm_item_counts = mm_items.get_all_counts() + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + use_audio_in_video = (all( + item["use_audio_in_video"].data + for item in mm_kwargs["video"]) if "video" in mm_kwargs else False) + + if use_audio_in_video and "video" in mm_item_counts: + assert "audio" in mm_item_counts + mm_item_counts["audio"] -= mm_item_counts["video"] + + if is_update_applied: + prompt_ids = self._get_raw_input_ids(prompt_ids, + use_audio_in_video) + + ( + prompt_ids, + prompt, + mm_placeholders, + ) = self._apply_prompt_updates( + prompt_ids, + mm_prompt_updates, + mm_item_counts, + ) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + tokenizer = self.info.get_tokenizer() + prompt = decode_tokens(tokenizer, prompt_ids) + + return prompt_ids, prompt, mm_placeholders + + def get_updates_use_audio_in_video( + self, + thinker_config: PretrainedConfig, + audio_len: int, + video_grid_thw: Union[list[int], torch.Tensor], + video_second_per_grid_t: float, + ) -> list[int]: + shift = 0 + audio_token_id = thinker_config.audio_token_id + video_token_id = thinker_config.video_token_id + audio_start_token_id = thinker_config.audio_start_token_id + audio_end_token_id = thinker_config.audio_end_token_id + spatial_merge_size = thinker_config.vision_config.spatial_merge_size + position_id_per_seconds = thinker_config.position_id_per_seconds + audio_token_indices = np.arange(next(iter([audio_len]))) + curr_video_grid_thw = next(iter([video_grid_thw])) + height = curr_video_grid_thw[1] // spatial_merge_size + width = curr_video_grid_thw[2] // spatial_merge_size + video_token_indices = np.arange(curr_video_grid_thw[0]).reshape( + -1, 1, 1) + video_token_indices = np.broadcast_to( + video_token_indices, + (video_token_indices.shape[0], height, width)).reshape(-1) + video_token_indices = ((video_token_indices + shift) * + next(iter([video_second_per_grid_t])) * + position_id_per_seconds) + video_data_index, audio_data_index = 0, 0 + updates = [audio_start_token_id] + while video_data_index < len( + video_token_indices) and audio_data_index < len( + audio_token_indices): + if video_token_indices[video_data_index] <= audio_token_indices[ + audio_data_index]: + updates += [video_token_id] + video_data_index += 1 + else: + updates += [audio_token_id] + audio_data_index += 1 + if video_data_index < len(video_token_indices): + updates += [video_token_id + ] * (len(video_token_indices) - video_data_index) + if audio_data_index < len(audio_token_indices): + updates += [audio_token_id + ] * (len(audio_token_indices) - audio_data_index) + updates += [audio_end_token_id] + return updates + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + image_token = processor.image_token + video_token = processor.video_token + audio_token_id = vocab[audio_token] + image_token_id = vocab[image_token] + video_token_id = vocab[video_token] + + audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths") + feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") + if audio_feature_lengths is None and feature_attention_mask is None: + audio_output_lengths = [] + elif audio_feature_lengths is not None: + _, audio_output_lens = self._get_feat_extract_output_lengths( + audio_feature_lengths) + audio_output_lengths = audio_output_lens.tolist() + elif feature_attention_mask is not None: + assert isinstance(feature_attention_mask, torch.Tensor) + _, audio_output_lens = self._get_feat_extract_output_lengths( + feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + + # number of audios read from video. + audio_in_video_item_idx = 0 + audio_item_idx = 0 + + def get_replacement_qwen2_audio(item_idx: int): + nonlocal audio_item_idx + item_idx += audio_in_video_item_idx + + audio_item_idx += 1 + + num_features = audio_output_lengths[item_idx] + if num_features == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model") + + return [audio_token_id] * num_features + + def get_replacement_qwen2_vision(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + merge_length = image_processor.merge_size**2 + + token_id = image_token_id if modality == "image" else video_token_id + return [token_id] * (int(grid_thw.prod()) // merge_length) + + use_audio_in_video = hf_processor_mm_kwargs.get( + "use_audio_in_video", False) + thinker_config = self.info.get_hf_config() + + def get_replacement_qwen2_use_audio_in_video(item_idx: int): + nonlocal audio_in_video_item_idx + audio_num_features = audio_output_lengths[audio_item_idx + + item_idx] + video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + + audio_in_video_item_idx += 1 + + second_per_grid_ts = hf_processor_mm_kwargs.get( + "second_per_grid_ts", None) + if second_per_grid_ts: + video_second_per_grid_t = second_per_grid_ts[item_idx] + else: + video_second_per_grid_t = 1.0 + + return self.get_updates_use_audio_in_video( + thinker_config=thinker_config, + audio_len=audio_num_features, + video_grid_thw=video_grid_thw, + video_second_per_grid_t=video_second_per_grid_t, + ) + + video_replacement_fn = ( + get_replacement_qwen2_use_audio_in_video if use_audio_in_video else + partial(get_replacement_qwen2_vision, modality="video")) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ), + PromptReplacement( + modality="image", + target=image_token, + replacement=partial(get_replacement_qwen2_vision, + modality="image"), + ), + PromptReplacement( + modality="video", + target=video_token, + replacement=video_replacement_fn, + ), + ] + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]], + mm_item_counts: Mapping[str, int], + ) -> None: + BaseMultiModalProcessor[ + Qwen2_5OmniThinkerProcessingInfo]._validate_mm_placeholders( + self, mm_placeholders, mm_item_counts) + + def _get_raw_input_ids( + self, + token_ids: list[int], + use_audio_in_video: bool = False, + ) -> list[int]: + + tokenizer = self.info.get_tokenizer() + vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0] + vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0] + audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0] + audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0] + audio_token = tokenizer.encode("<|audio_pad|>")[0] + image_token = tokenizer.encode("<|image_pad|>")[0] + video_token = tokenizer.encode("<|video_pad|>")[0] + + result = token_ids[:] + if use_audio_in_video: + while True: + start = None + for i in range(len(result) - 1): + if result[i:i + 2] == [vision_bos_token, audio_bos_token]: + start = i + break + if start is not None: + end = None + for i in range(start + 2, len(result) - 1): + if result[i:i + + 2] == [audio_eos_token, vision_eos_token]: + end = i + break + if end is not None: + result = result[:start] + [ + vision_bos_token, video_token, vision_eos_token + ] + result[end + 2:] + else: + break + + for mm_token in [audio_token, image_token, video_token]: + compressed = [] + for x in result: + if x != mm_token or (not compressed + or compressed[-1] != mm_token): + compressed.append(x) + result = compressed + + return result + + +class Qwen3OmniMoeConditionalGenerationMixin( + Qwen2_5OmniConditionalGenerationMixin): + + def _validate_and_reshape_mm_tensor(self, + mm_input: object, + name: str, + dim: int = 0) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if name == "feature_attention_mask": + dim = -1 + if isinstance(mm_input, torch.Tensor): + return torch.concat(list(mm_input), dim=dim) + else: + if isinstance(mm_input[0], list): + return torch.concat([ + torch.concat(mm_input[i], dim=dim) + for i in range(len(mm_input)) + ], + dim=dim) + else: + return torch.concat(mm_input, dim=dim) + + def _get_feat_extract_output_lengths( + self, input_lengths: torch.Tensor) -> torch.Tensor: + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - + 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths, output_lengths + + def _process_audio_input( + self, + audio_input: Qwen2AudioInputs, + audio_hashes: list[str] = None, + cached_audio_features: torch.Tensor = None, + ) -> torch.Tensor: + + input_features = audio_input["input_features"] + audio_feature_lengths = audio_input["audio_feature_lengths"] + + if input_features.ndim == 3: + assert input_features.shape[0] == 1 + input_features = input_features.squeeze(0) + + if not isinstance(audio_feature_lengths, torch.Tensor): + audio_feature_lengths = torch.cat(audio_feature_lengths) + if audio_feature_lengths.ndim == 2: + audio_feature_lengths = audio_feature_lengths.reshape(-1) + + audio_feat_lengths, audio_output_lengths = ( + self._get_feat_extract_output_lengths(audio_feature_lengths)) + + audio_outputs = self.audio_tower( + input_features.to(self.audio_tower.dtype), + feature_lens=audio_feature_lengths, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + return audio_features.split(audio_output_lengths.tolist()) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["image_embeds"].type(self.visual.dtype) + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if is_hpu: + assert isinstance(self.visual, + Qwen3Omni_VisionTransformerStaticShape) + image_embeds = self.visual.get_image_embeds( + pixel_values, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs, + video_hashes: list[str] = None, + cached_video_embeds: torch.Tensor = None) -> torch.Tensor: + if video_input["type"] == "video_embeds": + return video_input["video_embeds"].type(self.visual.dtype) + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + if is_hpu: + assert isinstance(self.visual, + Qwen3Omni_VisionTransformerStaticShape) + video_embeds = self.visual.get_image_embeds( + pixel_values_videos, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) + else: + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen3OmniMoeThinkerMultiModalProcessor, + info=Qwen3OmniMoeThinkerProcessingInfo, + dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder, +) +class Qwen3OmniMoeThinkerForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + Qwen3OmniMoeConditionalGenerationMixin, +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "thinker.lm_head.": "language_model.lm_head.", + "thinker.model.": "language_model.model.", + "thinker.": "", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + if modality.startswith("audio"): + return "<|audio_start|><|audio_pad|><|audio_end|>" + + raise ValueError("Only image, video or audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + thinker_config: Qwen3OmniMoeThinkerConfig = ( + vllm_config.model_config.hf_config.thinker_config) + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = thinker_config + self.config.architectures = [ + "Qwen3OmniMoeThinkerForConditionalGeneration" + ] + self.text_dim = thinker_config.text_config.hidden_size + self.multimodal_config = multimodal_config + + # force "use_flash_attention_2=True" to audio tower to align + # the results. + if flash_attn is not None: + audio_config = thinker_config.audio_config + audio_config._attn_implementation_autoset = True + audio_config._attn_implementation = "flash_attention_2" + else: + logger.warning( + "flash_attn is not available, the model may not yield the " + "exactly same result as the transformers implementation " + "in the audio tower part.") + + self.audio_tower = Qwen3OmniMoeAudioEncoder( + thinker_config.audio_config) + if is_hpu: + qwen3omni_visionTransformer = Qwen3Omni_VisionTransformerStaticShape + else: + qwen3omni_visionTransformer = Qwen3Omni_VisionTransformer + self.visual = qwen3omni_visionTransformer( + vision_config=thinker_config.vision_config, + norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + ) + self.quant_config = quant_config + + self.language_model = Qwen3MoeLLMForCausalLM( + vllm_config=vllm_config.with_hf_config( + thinker_config.text_config, + architectures=["Qwen3MoeForCausalLM"]), + prefix=maybe_prefix(prefix, "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self.use_deepstack = hasattr(thinker_config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + thinker_config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + # register buffer for deepstack + self.deepstack_input_embeds = [ + torch.zeros(1, vllm_config.scheduler_config.max_num_batched_tokens, + thinker_config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] if self.use_deepstack else None + self.visual_dim = thinker_config.vision_config.out_hidden_size + self.multiscale_dim = self.visual_dim * self.deepstack_num_level + + def _get_deepstack_input_embeds(self) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + if self.deepstack_input_embeds is None: + return None + return IntermediateTensors({ + f"deepstack_input_embeds_{idx}": + self.deepstack_input_embeds[idx] + for idx in range(self.deepstack_num_level) + }) + + def _set_deepstack_input_embeds( + self, deepstack_input_embeds: torch.Tensor) -> None: + if deepstack_input_embeds is None: + self.deepstack_input_embeds = None + return + + self.deepstack_input_embeds = [ + deepstack_input_embeds[idx].clone() + for idx in range(self.deepstack_num_level) + ] + + def _clear_deepstack_input_embeds(self, num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[idx][:num_tokens].zero_() + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if (input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality): + mm_input_by_modality["image"] = ( + self._parse_and_validate_image_input(**kwargs)) + if (input_key in ("pixel_values_videos", "video_embeds") + and "video" not in mm_input_by_modality): + mm_input_by_modality["video"] = ( + self._parse_and_validate_video_input(**kwargs)) + if (input_key in ("input_audio_features") + and "audio" not in mm_input_by_modality): + mm_input_by_modality["audio"] = ( + self._parse_and_validate_audio_input(**kwargs)) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + if modality == "audio": + audio_embeddings = self._process_audio_input(multimodal_input) + multimodal_embeddings += audio_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + deepstack_input_embeds = None + if multimodal_embeddings is not None and len( + multimodal_embeddings) != 0: + # TODO (ywang96): support overlapping modalitiy embeddings so that + # `use_audio_in_video` will work on V1. + # split the feat dim to obtain multi-scale visual feature + if self.visual.deepstack_visual_indexes is not None: + multiscale_len = len(self.visual.deepstack_visual_indexes) + multimodal_embeddings_multiscale = [] + for index, embeddings in enumerate(multimodal_embeddings): + if embeddings.shape[ + -1] != self.config.text_config.hidden_size: + visual_dim = embeddings.shape[-1] // (multiscale_len + + 1) + main_dim, multi_dim = (visual_dim, + visual_dim * multiscale_len) + embeddings_main, embeddings_multiscale = torch.split( + embeddings, [main_dim, multi_dim], dim=-1) + multimodal_embeddings[index] = embeddings_main + multimodal_embeddings_multiscale.append( + embeddings_multiscale) + if len(multimodal_embeddings_multiscale) > 0: + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), inputs_embeds.shape[1], + multiscale_len * inputs_embeds.size(1)) + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + multimodal_embeddings_multiscale, + placeholder_token_id=[ + self.config.image_token_id, + self.config.video_token_id + ], + ) + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], inputs_embeds.shape[1], + multiscale_len * visual_dim) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + [ + self.config.image_token_id, + self.config.video_token_id, + self.config.audio_token_id, + ], + ) + # self._set_deepstack_input_embeds(deepstack_input_embeds) + if deepstack_input_embeds is None: + deepstack_input_embeds = torch.zeros(inputs_embeds.size(0), + inputs_embeds.size(1), + multiscale_len * visual_dim, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype) + inputs_embeds = torch.cat([inputs_embeds, deepstack_input_embeds], + dim=2) + return inputs_embeds + + def get_multimodal_embeddings_v0( + self, **kwargs: object) -> Optional[NestedTensors]: + audio_input = self._parse_and_validate_audio_input(**kwargs) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if audio_input is None and image_input is None and video_input is None: + return None + + multimodal_embeddings: list[tuple[NestedTensors, str]] = [] + + if audio_input is not None: + audio_embeds = self._process_audio_input(audio_input) + multimodal_embeddings.append((audio_embeds, "audio")) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + multimodal_embeddings.append((image_embeds, "image")) + if video_input is not None: + video_embeds = self._process_video_input(video_input) + multimodal_embeddings.append((video_embeds, "video")) + return multimodal_embeddings + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings( + input_ids)[:, :, :self.text_dim] + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + # self._set_deepstack_input_embeds(None) + return inputs_embeds + + use_deepstack = self.visual.deepstack_visual_indexes is not None + if use_deepstack: + multiscale_len = len(self.visual.deepstack_visual_indexes) + visual_dim = self.text_dim + deepstack_input_embeds = None + for embeddings, modality in multimodal_embeddings: + if modality in ["image", "video"]: + deepstack_input_embeds = torch.zeros_like( + inputs_embeds).unsqueeze(2).repeat( + 1, 1, len(self.visual.deepstack_visual_indexes), + 1).flatten(2) + break + + for embeddings, modality in multimodal_embeddings: + if modality == "audio": + placeholder_token_id = self.config.audio_token_id + if modality == "image": + placeholder_token_id = self.config.image_token_id + if modality == "video": + placeholder_token_id = self.config.video_token_id + if use_deepstack and modality in ["image", "video"]: + embeddings = torch.cat(embeddings) + embeddings, embeddings_multiscale = embeddings.split( + [visual_dim, visual_dim * multiscale_len], dim=-1) + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + embeddings_multiscale, + placeholder_token_id=placeholder_token_id, + ) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, embeddings, placeholder_token_id) + + if use_deepstack and deepstack_input_embeds is not None: + # deepstack_input_embeds = deepstack_input_embeds.view( + # inputs_embeds.shape[0], inputs_embeds.shape[1], + # multiscale_len, visual_dim).permute(2, 0, 1, 3).contiguous() + # self._set_deepstack_input_embeds(deepstack_input_embeds) + inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds), + dim=-1) + + return inputs_embeds + + # self._set_deepstack_input_embeds(None) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + deepstack_input_embeds = None + if self.use_deepstack and inputs_embeds is not None and get_pp_group( + ).is_first_rank and inputs_embeds.size(2) > self.text_dim: + multiscale_len = len(self.visual.deepstack_visual_indexes) + input_embeds_multiscale = inputs_embeds[:, :, self.text_dim:] + inputs_embeds = inputs_embeds[:, :, :self.text_dim] + input_embeds_multiscale = input_embeds_multiscale.view( + inputs_embeds.shape[0], inputs_embeds.shape[1], multiscale_len, + self.text_dim).permute(2, 0, 1, 3).contiguous() + deepstack_input_embeds = IntermediateTensors({ + f"deepstack_input_embeds_{idx}": + input_embeds_multiscale[idx] + for idx in range(self.deepstack_num_level) + }) + # input_embeds_multiscale = self._get_deepstack_input_embeds() + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=["talker.", "code2wav."], + ) + loaded_weights = loader.load_weights(weights, + mapper=self.hf_to_vllm_mapper) + + return loaded_weights diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py new file mode 100644 index 000000000000..417e2d847a6c --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl.py @@ -0,0 +1,1647 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import BatchFeature +from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast +from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize +from transformers.models.qwen3_vl import (Qwen3VLProcessor, + Qwen3VLVideoProcessor) +from transformers.models.qwen3_vl.configuration_qwen3_vl import ( + Qwen3VLConfig, Qwen3VLVisionConfig) +from transformers.video_utils import VideoMetadata + +from vllm.attention.layer import check_upstream_fa_availability +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems, VideoItem) +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope +from vllm.utils import is_list_of + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .qwen2_5_vl import (Qwen2_5_VisionAttention, + Qwen2_5_VisionRotaryEmbedding, + Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs, + Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs, + Qwen2_5_VLVideoPixelInputs) +from .qwen2_vl import Qwen2VLProcessingInfo +from .qwen3 import Qwen3ForCausalLM, Qwen3Model +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + maybe_prefix, merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) +is_hpu = current_platform.is_hpu() + +if is_hpu: + import habana_frameworks.torch as htorch + import habana_frameworks.torch.core as htcore + + +class Qwen3_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen3_VisionMLP(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.linear_fc1 = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc1") + self.linear_fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + return_bias=False, + prefix=f"{prefix}.linear_fc2") + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return mlp_output + + +class Qwen3_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen3_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + attn_mask=attn_mask) + + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3_VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + + self.use_postshuffle_norm = use_postshuffle_norm + if self.use_postshuffle_norm: + context_dim = self.hidden_size + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = norm_layer( + self.hidden_size if use_postshuffle_norm else context_dim) + self.linear_fc1 = ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc1") + self.act_fn = nn.GELU() + self.linear_fc2 = RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + + x_parallel, _ = self.linear_fc1(x) + x_parallel = self.act_fn(x_parallel) + out, _ = self.linear_fc2(x_parallel) + return out + + +class Qwen3_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen3VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + self.num_position_embeddings = vision_config.num_position_embeddings + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.temporal_patch_size = vision_config.temporal_patch_size + self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + + self.patch_embed = Qwen3_VisionPatchEmbed( + patch_size=self.patch_size, + temporal_patch_size=self.temporal_patch_size, + in_channels=vision_config.in_channels, + hidden_size=self.hidden_size, + ) + + self.pos_embed = nn.Embedding(self.num_position_embeddings, + self.hidden_size) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen3_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(vision_config.depth) + ]) + + self.merger = Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + + self.deepstack_merger_list = nn.ModuleList([ + Qwen3_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + spatial_merge_size=self.spatial_merge_size, + use_postshuffle_norm=True, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.deepstack_merger_list.{layer_idx}") + for layer_idx in range(len(self.deepstack_visual_indexes)) + ]) + + self.attn_backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): + num_grid_per_side = int(self.num_position_embeddings**0.5) + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in grid_thw: + h_idxs = torch.linspace(0, + num_grid_per_side - 1, + h, + dtype=torch.float32) + w_idxs = torch.linspace(0, + num_grid_per_side - 1, + w, + dtype=torch.float32) + + h_idxs_floor = h_idxs.to(torch.long) + w_idxs_floor = w_idxs.to(torch.long) + h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1, + max=num_grid_per_side - 1) + w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1, + max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_floor[None]).flatten().tolist() * t) + idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T + + w_idxs_ceil[None]).flatten().tolist() * t) + + weight_list[0].extend( + ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[1].extend( + ((1 - dh)[None].T * dw[None]).flatten().tolist() * t) + weight_list[2].extend( + (dh[None].T * (1 - dw)[None]).flatten().tolist() * t) + weight_list[3].extend( + (dh[None].T * dw[None]).flatten().tolist() * t) + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + p0 = self.pos_embed( + torch.tensor( + idx_list[0], dtype=torch.long, device=device)) * torch.tensor( + weight_list[0], dtype=dtype, device=device)[:, None] + p1 = self.pos_embed( + torch.tensor( + idx_list[1], dtype=torch.long, device=device)) * torch.tensor( + weight_list[1], dtype=dtype, device=device)[:, None] + p2 = self.pos_embed( + torch.tensor( + idx_list[2], dtype=torch.long, device=device)) * torch.tensor( + weight_list[2], dtype=dtype, device=device)[:, None] + p3 = self.pos_embed( + torch.tensor( + idx_list[3], dtype=torch.long, device=device)) * torch.tensor( + weight_list[3], dtype=dtype, device=device)[:, None] + + patch_pos_embeds = p0 + p1 + p2 + p3 + patch_pos_embeds = patch_pos_embeds.split( + [t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] + m_size = self.spatial_merge_size + for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): + pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size, + m_size, -1).permute(0, 1, 3, 2, 4, + 5).flatten(0, 4) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def compute_attn_mask_seqlen( + self, + cu_seqlens: torch.Tensor, + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: list[list[int]], + ) -> torch.Tensor: + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + hidden_states = hidden_states.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("attn.qkv.", "attn.q.", "q"), + ("attn.qkv.", "attn.k.", "k"), + ("attn.qkv.", "attn.v.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3_VisionTransformerStaticShape(Qwen3_VisionTransformer): + """ + Here we overwrite some of the methods of Qwen3_VisionTransformer + to make the model more friendly to static shapes. Specifically, + we split the forward method into: + - pre_attn (dynamic) + - forward (static shape) + and we should call get_image_embeds instead of forward, allowing + the forward method ro run with HPU_Graphs, whereas the + pre_attn and post_attn methods are allow to be dynamic. + """ + + def pad_multimodal_data(self, + pixel_values, + vision_buckets, + constant_value=0): + + desired_number_of_pixels = vision_buckets.get_multimodal_bucket( + pixel_values.shape[0]) + padding_len = desired_number_of_pixels - pixel_values.shape[0] + if padding_len <= 0: + return pixel_values + + logger_msg = "Padding current number pixel " \ + + str(pixel_values.shape[0]) \ + + " to " \ + + str(desired_number_of_pixels) + logger.debug(logger_msg) + + pixel_values = F.pad(pixel_values, (0, 0, 0, padding_len), "constant", + constant_value) + + return pixel_values + + def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor, + vision_buckets): + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + rotary_pos_emb = self.rot_pos_emb(grid_thw) + attention_mask = torch.ones(hidden_states.size(0), + 1).to(device=self.device) + + hidden_states = self.pad_multimodal_data(hidden_states, vision_buckets, + 0) + rotary_pos_emb = self.pad_multimodal_data(rotary_pos_emb, + vision_buckets, -100) + attention_mask = self.pad_multimodal_data(attention_mask, + vision_buckets, 0) + + return hidden_states, rotary_pos_emb, attention_mask + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor], + rotary_pos_emb: torch.Tensor, + ) -> torch.Tensor: + hidden_states = x.unsqueeze(1) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, + rotary_pos_emb=rotary_pos_emb, + attn_mask=attn_mask, + cu_seqlens=None) + if layer_num in self.deepstack_visual_indexes: + deepstack_merger_idx = self.deepstack_visual_indexes.index( + layer_num) + deepstack_feature = self.deepstack_merger_list[ + deepstack_merger_idx](hidden_states) + deepstack_feature_lists.append(deepstack_feature) + hidden_states = self.merger(hidden_states) + hidden_states = torch.cat( + [hidden_states] + deepstack_feature_lists, + dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)] + return hidden_states + + def get_image_embeds( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + vision_buckets, + ) -> torch.Tensor: + + offset = 0 + results = [] + # process each image one by one + for img_idx in range(grid_thw.shape[0]): + img_shape = grid_thw[img_idx, :].unsqueeze(0).clone() + # For video, we process frames separately + grid_t = grid_thw[img_idx, 0] + img_shape[0, 0] = 1 + curr_img_size = img_shape.prod() + for _ in torch.arange(0, grid_t): + pixel_values_curr_img = pixel_values[offset:offset + + curr_img_size, :] + + offset += curr_img_size + + (pixel_values_curr_img_padded, rot_pos_emb, + attention_mask) = self.pre_attn(pixel_values_curr_img, + img_shape, vision_buckets) + + fullatt_block_attn_mask = \ + attention_mask.squeeze(1).unsqueeze(0) * attention_mask + + extra_forward_kwargs = {} + if htorch.utils.internal.is_lazy(): + padded_len = pixel_values_curr_img_padded.shape[0] + use_graph = vision_buckets.use_graph(padded_len) + extra_forward_kwargs.update( + {"bypass_hpu_graphs": not use_graph}) + + htcore.mark_step() + hidden_states = self.forward(pixel_values_curr_img_padded, + rotary_pos_emb=rot_pos_emb, + attn_mask=fullatt_block_attn_mask, + **extra_forward_kwargs) + htcore.mark_step() + + post_embed_size = curr_img_size // self.spatial_merge_unit + results += [hidden_states[:post_embed_size, :]] + + results_cat = torch.concat(results) + image_embeds = results_cat + return image_embeds + + +class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLConfig) + + def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor: + return self.ctx.get_hf_processor( + Qwen3VLProcessor, + use_fast=kwargs.pop("use_fast", True), + **kwargs, + ) + + def get_tokenizer(self): + return self.ctx.tokenizer + + def get_image_processor(self, + **kwargs: object) -> Qwen2VLImageProcessorFast: + return self.get_hf_processor(**kwargs).image_processor + + def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor: + return self.get_hf_processor(**kwargs).video_processor + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 2, + do_resize: bool = True, + image_processor: Optional[Qwen2VLImageProcessorFast], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.size["shortest_edge"], + max_pixels=image_processor.size["longest_edge"], + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + padded_num_frames = num_frames + num_frames % temporal_patch_size + + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def _calculate_timestamps(self, indices: list[int] | torch.Tensor, + video_fps: float, merge_size: int): + if not isinstance(indices, list): + indices = indices.tolist() + if len(indices) % merge_size != 0: + # don't update metadata's frames_indices directly + indices = indices + [indices[-1] + ] * (merge_size - len(indices) % merge_size) + timestamps = [idx / video_fps for idx in indices] + timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2 + for i in range(0, len(timestamps), merge_size)] + return timestamps + + def _get_video_second_idx( + self, + metadata: dict[str, Any], + do_sample_frames: Optional[bool] = None, + sampled_fps: Optional[float] = None) -> list[int]: + video_processor = self.get_video_processor() + merge_size = video_processor.merge_size + indices = metadata["frames_indices"] + + # metadata["fps"] refers to the true fps of the input video. + video_fps = metadata["fps"] + if do_sample_frames is None: + do_sample_frames = metadata.get("do_sample_frames", False) + + # If video frames are sampled in HF processor (instead of vLLM + # video loader), we need to re-calculate the indices from original + # metadata. + if do_sample_frames: + # here video_fps is the fps of the sampled video, and + # metadata["fps"] refers to the fps of the original video. + video_fps = sampled_fps if sampled_fps else video_processor.fps + total_num_frames = metadata["total_num_frames"] + num_frames = int(total_num_frames / metadata["fps"] * video_fps) + num_frames = min( + min(max(num_frames, video_processor.min_frames), + video_processor.max_frames), total_num_frames) + indices = np.linspace(0, total_num_frames - 1, + num_frames).round().astype(int).tolist() + timestamps = self._calculate_timestamps(indices, video_fps, merge_size) + return timestamps + + +class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_token = "<|vision_start|><|image_pad|><|vision_end|>" + video_token = "<|vision_start|><|video_pad|><|vision_end|>" + + return image_token * num_images + video_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts) + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ), + } + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[VideoItem]: + num_frames = max(num_frames, 2) + video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) + video_items = [] + for i in range(num_videos): + video_metadata = { + "fps": 2.0, + "duration": num_frames / 2.0, + "total_num_frames": num_frames, + "frames_indices": [i for i in range(num_frames)], + "video_backend": "opencv", + "do_sample_frames": False, + } + video_item = (video.copy(), video_metadata) + video_items.append(video_item) + return video_items + + +class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo] + ): + + def _get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser(video_needs_metadata=True) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + processor = self.info.get_hf_processor(**mm_kwargs) + + # Separate video processing from image processing. Because the videos + # are processed into serval image patches + if ("videos" in mm_data and isinstance(mm_data["videos"], list) + and len(mm_data["videos"]) > 0): + video_grid_thw_lst = [] + pixel_values_videos_lst = [] + + for item_idx, item in enumerate(mm_data.pop("videos", [])): + video_array, metadata = item + + # NOTE: @JJJYmmm new attr metadata.frames_indices indicates + # the sampled frames indices of pre-sampled videos, which is + # used to calculate the timestamps. Make sure that + # do_sample_frames in mm_kwargs is false for presampled videos. + + # NOTE: a copy of is created to update do_sample_frames, + # otherwise mm_hash for the object will be incorrect. + video_mm_kwargs = dict(**mm_kwargs) + if "do_sample_frames" not in video_mm_kwargs: + # qwen_vl_utils already has "do_sample_frames" in + # mm_kwargs, don't overwrite it. + video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", False) + + metadata = VideoMetadata(**{ + k: metadata[k] + for k in metadata if k != "do_sample_frames" + }) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt="<|vision_start|><|video_pad|><|vision_end|>", + mm_data=video_mm_data, + mm_kwargs=video_mm_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + video_placeholder = processor.tokenizer.batch_decode( + input_ids)[0] + prompt = prompt.replace( + "<|vision_start|><|video_pad|><|vision_end|>", + video_placeholder, + 1, + ) + + video_grid_thw_lst.append(video_outputs["video_grid_thw"]) + pixel_values_videos_lst.append( + video_outputs["pixel_values_videos"]) + video_outputs = dict( + pixel_values_videos=torch.cat(pixel_values_videos_lst), + video_grid_thw=torch.cat(video_grid_thw_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + hf_config = self.info.get_hf_config() + + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + vision_end_token_id = hf_config.vision_end_token_id + + merge_length = image_processor.merge_size**2 + + def get_image_replacement_qwen3vl(item_idx: int): + grid_thw = out_mm_kwargs["image_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [hf_processor.image_token_id] * num_tokens + + def get_video_replacement_qwen3vl(item_idx: int): + grid_thw = out_mm_kwargs["video_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + + video, metadata = mm_items["video"][item_idx] + do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") + sampled_fps = hf_processor_mm_kwargs.get("fps") + if is_list_of(sampled_fps, float): + sampled_fps = sampled_fps[item_idx] + timestamps = self.info._get_video_second_idx( + metadata, do_sample_frames, sampled_fps) + + assert len(timestamps) == grid_thw[0], ( + f"The timestamps length({len(timestamps)}) should be equal " + f"video length ({grid_thw[0]}).") + + frames_idx_token = [ + tokenizer.encode(f"<{curr_time:.1f} seconds>", + add_special_tokens=False) + for curr_time in timestamps + ] + num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + placeholder = [] + for frame_idx in frames_idx_token: + placeholder.extend(frame_idx) + placeholder.extend([vision_start_token_id] + + [video_token_id] * num_tokens_per_frame + + [vision_end_token_id]) + return PromptUpdateDetails.select_token_id(placeholder, + video_token_id) + + return [ + PromptReplacement( + modality="image", + target=hf_processor.image_token, + replacement=get_image_replacement_qwen3vl, + ), + + # NOTE: We match string on purpose since searching sequence of + # token ids takes more time. + PromptReplacement( + modality="video", + target="<|vision_start|><|video_pad|><|vision_end|>", + replacement=get_video_replacement_qwen3vl, + ), + ] + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0 + }) +class Qwen3LLMModel(Qwen3Model): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config. + deepstack_visual_indexes), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + # args for deepstack + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + if is_hpu: + htorch.core.mark_step() + for layer_idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and \ + layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[ + f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3LLMForCausalLM(Qwen3ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3ForCausalLM, self).__init__() + config = vllm_config.model_config.hf_config.text_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix="lm_head") + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.visual.": "visual.", + "lm_head.": "language_model.lm_head.", + "model.language_model.": "language_model.model.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + if modality.startswith("video"): + return "<|vision_start|><|video_pad|><|vision_end|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): + super().__init__() + config: Qwen3VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.text_dim = config.text_config.hidden_size + + if is_hpu: + qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape + else: + qwen3_visionTransformer = Qwen3_VisionTransformer + + self.visual = qwen3_visionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, + "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self.use_deepstack = hasattr(config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + # register buffer for deepstack + self.deepstack_input_embeds = [ + torch.zeros(1, vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] if self.use_deepstack else None + + def _get_deepstack_input_embeds(self) -> IntermediateTensors: + # get deepstack_input_embeds from buffer, and clear the buffer + return IntermediateTensors({ + f"deepstack_input_embeds_{idx}": + self.deepstack_input_embeds[idx] + for idx in range(self.deepstack_num_level) + }) + + def _set_deepstack_input_embeds( + self, deepstack_input_embeds: torch.Tensor) -> None: + self.deepstack_input_embeds = [ + deepstack_input_embeds[idx].clone() + for idx in range(self.deepstack_num_level) + ] + + def _clear_deepstack_input_embeds(self, batch_size: int, + num_tokens: int) -> None: + # clear deepstack_input_embeds in buffer + if num_tokens > 0: + for idx in range(self.deepstack_num_level): + self.deepstack_input_embeds[ + idx][:batch_size, :num_tokens].zero_() + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + if is_hpu: + assert isinstance(self.visual, + Qwen3_VisionTransformerStaticShape) + image_embeds = self.visual.get_image_embeds( + pixel_values, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return image_embeds.split(sizes) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + if is_hpu: + assert isinstance(self.visual, + Qwen3_VisionTransformerStaticShape) + video_embeds = self.visual.get_image_embeds( + pixel_values_videos, + grid_thw=grid_thw, + vision_buckets=self.vision_buckets, + ) + else: + video_embeds = self.visual(pixel_values_videos, + grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync + merge_size = self.visual.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + return video_embeds.split(sizes) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + mm_input_by_modality = {} + for input_key in kwargs: + if input_key in ("pixel_values", "image_embeds" + ) and "image" not in mm_input_by_modality: + mm_input_by_modality[ + "image"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos", "video_embeds" + ) and "video" not in mm_input_by_modality: + mm_input_by_modality[ + "video"] = self._parse_and_validate_video_input(**kwargs) + return mm_input_by_modality + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + mm_input_by_modality = self._parse_and_validate_multimodal_inputs( + **kwargs) + if not mm_input_by_modality: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in mm_input_by_modality: + multimodal_input = mm_input_by_modality[modality] + if modality == "image": + vision_embeddings = self._process_image_input(multimodal_input) + multimodal_embeddings += vision_embeddings + if modality == "video": + video_embeddings = self._process_video_input(multimodal_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def _compute_deepstack_embeds( + self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor: + visual_lens = [ + x.shape[0] if isinstance(x, torch.Tensor) else len(x) + for x in multimodal_embeddings + ] + multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0) + + visual_dim = multimodal_embeddings_cat.shape[-1] // ( + self.deepstack_num_level + 1) + + main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level + multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501 + multimodal_embeddings_cat, [main_dim, multi_dim], + dim=-1) + + multimodal_embeddings = torch.split(multimodal_embeddings_main, + visual_lens, + dim=0) + multimodal_embeddings_multiscale = torch.split( + multimodal_embeddings_multiscale, visual_lens, dim=0) + + deepstack_input_embeds = inputs_embeds.new_zeros( + inputs_embeds.size(0), inputs_embeds.size(1), + self.deepstack_num_level * inputs_embeds.size(1)) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + multimodal_embeddings_multiscale, + placeholder_token_id=[ + self.config.image_token_id, self.config.video_token_id + ], + ) + deepstack_input_embeds = deepstack_input_embeds.view( + inputs_embeds.shape[0], inputs_embeds.shape[1], + self.deepstack_num_level, visual_dim).contiguous() + deepstack_input_embeds = deepstack_input_embeds.permute( + 2, 0, 1, 2).contiguous() + return deepstack_input_embeds, multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + deepstack_input_embeds = None + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None and self.use_deepstack: + deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds( # noqa:E501 + input_ids, inputs_embeds, multimodal_embeddings) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + + # if self.use_deepstack: + # if deepstack_input_embeds is None: + # deepstack_input_embeds = torch.zeros_like( + # inputs_embeds).unsqueeze(0).repeat( + # self.deepstack_num_level, 1, 1, 1).contiguous() + # self._set_deepstack_input_embeds(deepstack_input_embeds) + if deepstack_input_embeds is None: + deepstack_input_embeds = torch.zeros(inputs_embeds.size(0), + inputs_embeds.size(1), + self.deepstack_num_level * + self.text_dim, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype) + inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds), + dim=-1) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Qwen2_5_VLImageInputs] = None, + video_input: Optional[Qwen2_5_VLVideoInputs] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings( + input_ids)[:, :, :self.text_dim] + + if self.use_deepstack: + visual_dim = self.text_dim + deepstack_input_embeds = None + if image_input is not None or video_input is not None: + deepstack_input_embeds = torch.zeros_like( + inputs_embeds).unsqueeze(2).repeat( + 1, 1, self.deepstack_num_level, 1).flatten(2) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + if self.use_deepstack: + image_embeds = torch.cat(image_embeds) + + image_embeds, image_embeds_multiscale = image_embeds.split( + [visual_dim, visual_dim * self.deepstack_num_level], + dim=-1) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + image_embeds_multiscale, + placeholder_token_id=self.config.image_token_id, + ) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + if self.use_deepstack: + video_embeds = torch.cat(video_embeds) + + video_embeds, video_embeds_multiscale = video_embeds.split( + [visual_dim, visual_dim * self.deepstack_num_level], + dim=-1) + + deepstack_input_embeds = merge_multimodal_embeddings( + input_ids, + deepstack_input_embeds, + video_embeds_multiscale, + placeholder_token_id=self.config.video_token_id, + ) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + if self.use_deepstack and deepstack_input_embeds is not None: + # deepstack_input_embeds = deepstack_input_embeds.view( + # inputs_embeds.shape[0], inputs_embeds.shape[1], + # self.deepstack_num_level, visual_dim).permute(2, 0, 1, + # 3).contiguous() + # self._set_deepstack_input_embeds(deepstack_input_embeds) + inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds), + dim=-1) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen3VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen3VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + deepstack_input_embeds = None + if self.use_deepstack and inputs_embeds is not None and get_pp_group( + ).is_first_rank and inputs_embeds.size(2) > self.text_dim: + multiscale_len = len(self.visual.deepstack_visual_indexes) + input_embeds_multiscale = inputs_embeds[:, :, self.text_dim:] + inputs_embeds = inputs_embeds[:, :, :self.text_dim] + input_embeds_multiscale = input_embeds_multiscale.view( + inputs_embeds.shape[0], inputs_embeds.shape[1], multiscale_len, + self.text_dim).permute(2, 0, 1, 3).contiguous() + deepstack_input_embeds = IntermediateTensors({ + f"deepstack_input_embeds_{idx}": + input_embeds_multiscale[idx] + for idx in range(self.deepstack_num_level) + }) + # deepstack_input_embeds = self._get_deepstack_input_embeds() + else: + deepstack_input_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + # args for deepstack + deepstack_input_embeds=deepstack_input_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="model.visual.merger", + tower_model="model.visual.", + ) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py new file mode 100644 index 000000000000..04f87a078cf5 --- /dev/null +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights.""" +import typing +from collections.abc import Iterable +from typing import Callable, Optional, Union + +import torch +from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import ( + Qwen3VLMoeConfig) + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors + +from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel +from .qwen3_vl import (Qwen3_VisionTransformer, + Qwen3_VisionTransformerStaticShape, + Qwen3VLDummyInputsBuilder, + Qwen3VLForConditionalGeneration, + Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo) +from .utils import is_pp_missing_parameter, maybe_prefix + +logger = init_logger(__name__) +is_hpu = current_platform.is_hpu() + +if is_hpu: + import habana_frameworks.torch as htorch + + +class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen3VLMoeConfig) + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + # the same shape as input_embeds + "deepstack_input_embeds": 0 + }) +class Qwen3MoeLLMModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + if not get_pp_group().is_first_rank: + assert self.start_layer >= len( + vllm_config.model_config.hf_config.vision_config. + deepstack_visual_indexes), ( + "start_layer should be greater than or equal to " + "len(deepstack_visual_indexes)") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + deepstack_input_embeds: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + if is_hpu: + htorch.core.mark_step() + for layer_idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + layer_idx = layer_idx + self.start_layer + + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if deepstack_input_embeds is not None and \ + layer_idx in range(0, len(deepstack_input_embeds)): + hidden_states = hidden_states + deepstack_input_embeds[ + f"deepstack_input_embeds_{layer_idx}"] + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_fused_expert_weights(self, name: str, params_dict: dict, + loaded_weight: torch.Tensor, shard_id: str, + num_experts: int): + param = params_dict[name] + weight_loader = typing.cast(Callable[..., bool], param.weight_loader) + for expert_id in range(num_experts): + curr_expert_weight = loaded_weight[expert_id] + weight_loader(param, curr_expert_weight, name, shard_id, expert_id) + return True + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + # Skip loading extra parameters for GPTQ/modelopt models. + ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", + ".v_scale", "_v_scale", ".weight_scale", + "_weight_scale", ".input_scale", "_input_scale") + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + is_fused_expert = False + fused_expert_params_mapping = [ + ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"), + ("experts.w2_weight", "experts.down_proj", 0, "w2"), + ] + num_experts = self.config.num_experts + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if ("experts.gate_up_proj" in name + or "experts.down_proj" in name): + is_fused_expert = True + expert_params_mapping = fused_expert_params_mapping + + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith(ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + name_mapped = name.replace(weight_name, param_name) + if is_fused_expert: + loaded_weight = loaded_weight.transpose(-1, + -2) # no bias + if "experts.gate_up_proj" in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight[0], + "w1", num_experts) + self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight[1], + "w3", num_experts) + else: + # down_proj + self.load_fused_expert_weights( + name_mapped, params_dict, loaded_weight, + shard_id, num_experts) + else: + if is_pp_missing_parameter(name_mapped, self): + continue + # Skip loading extra parameters for GPTQ/modelopt models + if name_mapped.endswith( + ignore_suffixes + ) and name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id) + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra parameters for GPTQ/modelopt models. + if name.endswith( + ignore_suffixes) and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 + name, + remapped_kv_scale_name, + ) + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3MoeForCausalLM, self).__init__() + self.config = vllm_config.model_config.hf_config.text_config + self.quant_config = vllm_config.quant_config + self.model = Qwen3MoeLLMModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size, + quant_config=self.quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + +@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor, + info=Qwen3VLMoeProcessingInfo, + dummy_inputs=Qwen3VLDummyInputsBuilder) +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super(Qwen3VLForConditionalGeneration, self).__init__() + config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self.text_dim = config.text_config.hidden_size + + if is_hpu: + qwen3_visionTransformer = Qwen3_VisionTransformerStaticShape + else: + qwen3_visionTransformer = Qwen3_VisionTransformer + + self.visual = qwen3_visionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, + "language_model")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self.use_deepstack = hasattr(config.vision_config, + 'deepstack_visual_indexes') + self.deepstack_num_level = len( + config.vision_config.deepstack_visual_indexes + ) if self.use_deepstack else 0 + # register buffer for deepstack + self.deepstack_input_embeds = [ + torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] if self.use_deepstack else None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 47b191007ee1..4ebd5e1638ce 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -220,9 +220,13 @@ "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 "Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501 - "UltravoxModel": ("ultravox", "UltravoxModel"), + "Qwen3OmniMoeForConditionalGeneration": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"), # noqa: E501 + "Qwen3OmniMoeModel": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"), # noqa: E501 + "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501 + "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501 "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 + "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index b56ad407ce55..fdcd04bdd6d3 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,11 +7,11 @@ from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias +from typing_extensions import NotRequired, TypeAlias, TypeVar from vllm.jsontree import JSONTree, json_map_leaves from vllm.utils import LazyLoader, full_groupby, is_list_of @@ -601,6 +601,85 @@ def modality(self) -> str: return next(iter(modalities)) +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): + """ + A dictionary of + [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s + by modality. + """ + + @staticmethod + def from_hf_inputs( + hf_inputs: "BatchFeature", + config_by_key: Mapping[str, MultiModalFieldConfig], + ): + # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` + # We assume that those fields are not used in vLLM + elems_by_key = dict[str, Sequence[MultiModalFieldElem]]() + keys_by_modality = defaultdict[str, set[str]](set) + for key, config in config_by_key.items(): + batch = hf_inputs.get(key) + if batch is not None: + elems = config.build_elems(key, batch) + if len(elems) > 0: + elems_by_key[key] = elems + keys_by_modality[config.modality].add(key) + + items = list[MultiModalKwargsItem]() + for modality, keys in keys_by_modality.items(): + elems_in_modality = {k: elems_by_key[k] for k in keys} + batch_sizes = {k: len(v) for k, v in elems_in_modality.items()} + + if len(set(batch_sizes.values())) > 1: + raise ValueError( + f"Cannot merge different batch sizes for {modality=}! " + f"Found: {batch_sizes=}") + + batch_size = next(iter(batch_sizes.values())) + for item_idx in range(batch_size): + elems = [v[item_idx] for v in elems_in_modality.values()] + items.append(MultiModalKwargsItem.from_elems(elems)) + + return MultiModalKwargsItems.from_seq(items) + + @staticmethod + def from_seq(items: Sequence[MultiModalKwargsItem]): + items_by_modality = full_groupby(items, key=lambda x: x.modality) + return MultiModalKwargsItems(items_by_modality) + + def __getitem__(self, modality: str) -> Sequence[_I]: + if modality not in self: + raise KeyError(f"Modality {modality!r} not found. " + f"Available modalities: {set(self.keys())}") + + return super().__getitem__(modality) # type: ignore[return-value] + + def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") + + for key, elem in item.items(): + elems_by_key[key].append(elem) + + return MultiModalKwargs({ + key: + elems[0].field.reduce_data(elems, pin_memory=pin_memory) + for key, elems in elems_by_key.items() + }) + + # NOTE: UserDict is for V0 compatibility. # V1 should access individual items via `get_item`. class MultiModalKwargs(UserDict[str, NestedTensors]): diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index f15264cbde0b..903d3f49c6bd 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -146,9 +146,12 @@ def load_bytes(cls, # Use transformers transformers.video_utils.VideoMetadata format metadata = { "total_num_frames": total_frames_num, - "fps": original_fps, - "duration": duration, - "video_backend": "opencv" + "fps": num_frames / duration, + "video_backend": "opencv", + "frames_indices": list(range(num_frames)), + # extra field used to control hf processor's video + # sampling behavior + "do_sample_frames": num_frames == total_frames_num, } return frames, metadata diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 32b0c4d711b6..4e9a38b05f94 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -723,7 +723,9 @@ def compute_input_embeddings_for_mrope_mm_optimized( input_ids = kwargs['input_ids'] with compile_only_mode_context_false(): if self.model_is_mrope: - if self.model.config.model_type == 'qwen2_5_omni_thinker': + if self.model.config.model_type in [ + 'qwen2_5_omni_thinker', 'qwen3_omni_moe_thinker' + ]: multimodal_embeddings = \ self.model.get_multimodal_embeddings_v0(**kwargs) inputs_embeds = self.model.get_input_embeddings_v0( @@ -2801,9 +2803,15 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args, image_h = img_args // 8 image_grid_thw = torch.tensor( [[1, image_h, int(img_args / image_h)]]) + embed_dim = 1176 + if any([ + model_type in self.get_model().config.model_type + for model_type in ['qwen3_vl', "qwen3_omni"] + ]): + embed_dim = 1536 pixel_values = torch.randn( image_grid_thw[0].prod(), - 1176) # TODO: figure out the variable name + embed_dim) # TODO: figure out the variable name assert pixel_values.shape[0] % 64 == 0, ( f"pixel_values must be sliced in 64 chunks, "