diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index ec736aa236ff..f958a150a84b 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -562,7 +562,10 @@ Specified using `--task generate`.
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ |
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
+| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* |
+| `Qwen3VLForConditionalGeneration` | Qwen3-VL | T + IE+ + VE+ | `Qwen/Qwen3-VL-4B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3VLMoeForConditionalGeneration` | Qwen3-VL-MOE | T + IE+ + VE+ | `Qwen/Qwen3-VL-30B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Qwen3OmniMoeThinkerForConditionalGeneration` | Qwen3-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen3-Omni-30B-A3B-Instruct`, `Qwen/Qwen3-Omni-30B-A3B-Thinking` | ✅︎ | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + IE+ | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 825abeaf7e75..4bc622151d11 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -1076,6 +1076,116 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
)
+# Qwen3-VL-Dense
+def run_qwen3_vl(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "Qwen/Qwen3-VL-4B-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
+ mm_processor_kwargs={
+ "min_pixels": 28 * 28,
+ "max_pixels": 1280 * 28 * 28,
+ "fps": 1,
+ },
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ if modality == "image":
+ placeholder = "<|image_pad|>"
+ elif modality == "video":
+ placeholder = "<|video_pad|>"
+
+ prompts = [
+ (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+ f"{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
+# Qwen3-VL-MOE
+def run_qwen3_vl_moe(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
+ mm_processor_kwargs={
+ "min_pixels": 28 * 28,
+ "max_pixels": 1280 * 28 * 28,
+ "fps": 1,
+ },
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ if modality == "image":
+ placeholder = "<|image_pad|>"
+ elif modality == "video":
+ placeholder = "<|video_pad|>"
+
+ prompts = [
+ (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+ f"{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
+def run_qwen3_omni_moe(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
+ mm_processor_kwargs={
+ "min_pixels": 28 * 28,
+ "max_pixels": 1280 * 28 * 28,
+ "fps": 1,
+ },
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ if modality == "image":
+ placeholder = "<|image_pad|>"
+ elif modality == "video":
+ placeholder = "<|video_pad|>"
+
+ prompts = [
+ (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+ f"{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1146,11 +1256,20 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni,
+ "qwen3_vl": run_qwen3_vl,
+ "qwen3_vl_moe": run_qwen3_vl_moe,
+ "qwen3_omni_moe": run_qwen3_omni_moe,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
"tarsier": run_tarsier,
}
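+# These models expect the video input as a (frames, metadata) tuple rather
+# than a bare frame array.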
+MODELS_NEED_VIDEO_METADATA = [
+ "glm4_1v",
+ "qwen3_vl",
+ "qwen3_vl_moe",
+]
+
def get_multi_modal_input(args):
"""
@@ -1176,12 +1295,13 @@ def get_multi_modal_input(args):
if args.modality == "video":
# Input video and question
+ needs_metadata = args.model_type in MODELS_NEED_VIDEO_METADATA
video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
metadata = VideoAsset(name="baby_reading", num_frames=args.num_frames).metadata
vid_questions = ["Why is this video funny?"]
return {
- "data": [(video, metadata)] if args.model_type == "glm4_1v" else video,
+ "data": [(video, metadata)] if needs_metadata else video,
"questions": vid_questions,
}
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index e52a42a2e452..84f2b7cd368d 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -30,6 +30,7 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"""
# Ensure video metadata is included
if "video" in mm_data:
+ # GLM4.1V doesn't support multiple videos
video = mm_data["video"]
mm_data["video"] = (video, {
"total_num_frames": len(video),
@@ -39,6 +40,33 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
})
return mm_data
+def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
+ """
+ Patch the multimodal data for Qwen3-VL model.
+ """
+
+ def create_metadata(frames: np.ndarray):
+ num_frames = len(frames)
+ return {
+ "total_num_frames": num_frames,
+ "fps": 2.0,
+ "duration": num_frames / 2.0,
+ "video_backend": "opencv",
+ "frames_indices": list(range(num_frames)),
+ "do_sample_frames": True,
+ }
+
+ # Ensure video metadata is included
+ if "video" in mm_data:
+ video = mm_data["video"]
+ if isinstance(video, list):
+ # multiple videos
+ mm_data["video"] = [(vid, create_metadata(vid)) for vid in video]
+ else:
+ # single video
+ mm_data["video"] = (video, create_metadata(video))
+ return mm_data
+
def _test_processing_correctness(
model_id: str,
@@ -171,8 +199,10 @@ def _test_processing_correctness(
}
MM_DATA_PATCHES = {
- # GLM4.1V requires video metadata to be included in the input
+    # GLM4.1V and Qwen3-VL require video metadata to be included in the input
"glm4v": glm4_1v_patch_mm_data,
+ "qwen3_vl": qwen3_vl_patch_mm_data,
+ "qwen3_vl_moe": qwen3_vl_patch_mm_data,
}
@@ -304,6 +334,8 @@ def _test_processing_correctness_one(
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-3B",
+ "Qwen/Qwen3-VL-4B-Instruct",
+ "Qwen/Qwen3-VL-30B-A3B-Instruct",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index c90aed31f597..ac7545bd42c9 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -404,6 +404,15 @@ def check_available_online(
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
+ "Qwen3VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-4B-Instruct", # noqa: E501
+ max_model_len=4096,
+ min_transformers_version="4.57"), # noqa: E501
+ "Qwen3VLMoeForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-VL-30B-A3B-Instruct", # noqa: E501
+ max_model_len=4096,
+ min_transformers_version="4.57"),
+ "Qwen3OmniMoeThinkerForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen3-Omni-30B-A3B-Instruct", # noqa: E501
+ max_model_len=4096,
+ min_transformers_version="4.57"),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
diff --git a/vllm/assets/video.py b/vllm/assets/video.py
index 16412121cf0a..93a4733d46b5 100644
--- a/vllm/assets/video.py
+++ b/vllm/assets/video.py
@@ -77,7 +77,7 @@ def video_to_pil_images_list(path: str,
]
-def video_get_metadata(path: str) -> dict[str, Any]:
+def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise ValueError(f"Could not open video file {path}")
@@ -86,11 +86,18 @@ def video_get_metadata(path: str) -> dict[str, Any]:
fps = cap.get(cv2.CAP_PROP_FPS)
duration = total_frames / fps if fps > 0 else 0
+ if num_frames == -1 or num_frames > total_frames:
+ num_frames = total_frames
+
metadata = {
"total_num_frames": total_frames,
"fps": fps,
"duration": duration,
- "video_backend": "opencv"
+ "video_backend": "opencv",
+ "frames_indices": list(range(num_frames)),
+ # extra field used to control hf processor's video
+ # sampling behavior
+ "do_sample_frames": num_frames == total_frames,
}
return metadata
@@ -126,7 +133,7 @@ def np_ndarrays(self) -> npt.NDArray:
@property
def metadata(self) -> dict[str, Any]:
video_path = download_video_asset(self.filename)
- ret = video_get_metadata(video_path)
+ ret = video_get_metadata(video_path, self.num_frames)
return ret
def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 0ce8ab650d06..6050a5a368a5 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -24,6 +24,14 @@
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
+def check_upstream_fa_availability(dtype: torch.dtype):
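+    """Check whether upstream flash-attn 2 can be used for this dtype on the
+    current platform (CUDA with compute capability >= 8.0)."""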
+ if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+ ) and current_platform.has_device_capability(80):
+ from transformers.utils import is_flash_attn_2_available
+ return is_flash_attn_2_available()
+ return False
+
+
class Attention(nn.Module):
"""Attention layer.
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 73657666dfdd..783673aa4d96 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -537,6 +537,8 @@ def _placeholder_str(self, modality: ModalityStr,
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
+ if model_type in ("qwen3_omni_moe", "qwen3_vl_moe", "qwen3_vl"):
+ return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "aria":
@@ -555,6 +557,8 @@ def _placeholder_str(self, modality: ModalityStr,
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
+ if model_type == "qwen3_omni_moe":
+ return "<|audio_start|><|audio_pad|><|audio_end|>"
if model_type == "minicpmo":
return "()"
raise TypeError(f"Unknown model type: {model_type}")
@@ -567,6 +571,8 @@ def _placeholder_str(self, modality: ModalityStr,
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
+ if model_type in ("qwen3_omni_moe", "qwen3_vl_moe", "qwen3_vl"):
+ return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "()"
if model_type.startswith("llava"):
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index ab237a460dd2..2500eab09ce7 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -1168,6 +1168,18 @@ def forward_hpu( # type: ignore[override]
return query_out.type_as(query), key_out.type_as(key)
+def apply_interleaved_rope(x: torch.Tensor,
+ mrope_section: list[int]) -> torch.Tensor:
+ """Apply interleaved MRoPE to 3D rotary embeddings.
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+    interleaved [THWTHWTHW...TT], preserving frequency continuity.
+ """
+ x_t = x[0].clone()
+ x_t[..., 1:mrope_section[1] * 3:3] = x[1, ..., 1:mrope_section[1] * 3:3]
+ x_t[..., 2:mrope_section[2] * 3:3] = x[2, ..., 2:mrope_section[2] * 3:3]
+ return x_t
+
+
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
@@ -1180,6 +1192,7 @@ def __init__(
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
+ mrope_interleaved: Optional[bool] = False,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@@ -1189,6 +1202,7 @@ def __init__(
base, is_neox_style, dtype)
self.mrope_section = mrope_section
+ self.mrope_interleaved = mrope_interleaved
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
@@ -1215,17 +1229,20 @@ def forward(
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
-
- cos = torch.cat([
- m[i]
- for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
- ],
- dim=-1)
- sin = torch.cat([
- m[i]
- for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
- ],
- dim=-1)
+ if self.mrope_interleaved:
+ cos = apply_interleaved_rope(cos, self.mrope_section)
+ sin = apply_interleaved_rope(sin, self.mrope_section)
+ else:
+ cos = torch.cat([
+ m[i] for i, m in enumerate(
+ cos.split(self.mrope_section, dim=-1))
+ ],
+ dim=-1)
+ sin = torch.cat([
+ m[i] for i, m in enumerate(
+ sin.split(self.mrope_section, dim=-1))
+ ],
+ dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
@@ -1292,7 +1309,7 @@ def get_input_positions_tensor(
) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config) and \
- hf_config.model_type == "qwen2_5_omni":
+ hf_config.model_type in ["qwen2_5_omni","qwen3_omni_moe"]:
return cls._omni_get_input_positions_tensor(
input_tokens=input_tokens,
hf_config=hf_config,
@@ -1313,6 +1330,15 @@ def get_input_positions_tensor(
context_len=context_len,
seq_len=seq_len,
)
+ elif hf_config.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+ return cls._qwen3vl_get_input_positions_tensor(
+ input_tokens=input_tokens,
+ hf_config=hf_config,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ context_len=context_len,
+ seq_len=seq_len,
+ )
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
@@ -1532,6 +1558,386 @@ def _vl_get_input_positions_tensor(
return llm_positions, mrope_position_delta
+ @classmethod
+ def _qwen3vl_get_input_positions_tensor(
+ cls,
+ input_tokens: list[int],
+ hf_config: PretrainedConfig,
+ image_grid_thw: Union[list[list[int]], torch.Tensor],
+ video_grid_thw: Union[list[list[int]], torch.Tensor],
+ context_len: int = 0,
+ seq_len: Optional[int] = None,
+ ) -> tuple[torch.Tensor, int]:
+ """Get mrope input positions and delta value."""
+
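+        # Qwen3-VL positions each video frame like a separate image: the
+        # (t, h, w) grid of every video is flattened into t (1, h, w) grids.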
+ video_grid_thw = [[1, h, w] for t, h, w in video_grid_thw
+ for _ in range(t)]
+
+ image_token_id = hf_config.image_token_id
+ video_token_id = hf_config.video_token_id
+ vision_start_token_id = hf_config.vision_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+ input_tokens_tensor = torch.tensor(input_tokens)
+ vision_start_indices = torch.argwhere(
+ input_tokens_tensor == vision_start_token_id).squeeze(1)
+ vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ llm_pos_ids_list: list = []
+
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+
+ image_index, video_index = 0, 0
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+
+ llm_grid_t, llm_grid_h, llm_grid_w = \
+ t, h // spatial_merge_size, w // spatial_merge_size
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+ -1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+ llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+ llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ mrope_position_delta = (llm_positions.max() + 1 -
+ len(input_tokens)).item()
+ llm_positions = llm_positions[:, context_len:seq_len]
+ return llm_positions, mrope_position_delta
+
+ @classmethod
+ def _omni3_get_input_positions_tensor(
+ cls,
+ config,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_audio_in_video: bool = False,
+ audio_seqlens: Optional[torch.LongTensor] = None,
+ second_per_grids: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
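+        """Compute 3D MRoPE position ids for Qwen3-Omni (thinker), following
+        the HF implementation: text, audio, image and video segments receive
+        consecutive position blocks, and with ``use_audio_in_video`` the audio
+        and video positions are interleaved in timestamp order."""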
+
+ def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
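+            # Mel frames are consumed in 100-frame chunks: each full chunk
+            # contributes 13 encoder outputs, while the remainder is
+            # downsampled roughly 8x by the strided convolutions.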
+ input_lengths_leave = input_lengths % 100
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
+ output_lengths = ((feat_lengths - 1) // 2 + 1 -
+ 1) // 2 + 1 + (input_lengths // 100) * 13
+ return output_lengths
+
+ spatial_merge_size = config.vision_config.spatial_merge_size
+ image_token_id = config.image_token_id
+ video_token_id = config.video_token_id
+ audio_token_id = config.audio_token_id
+ vision_start_token_id = config.vision_start_token_id
+ audio_start_token_id = config.audio_start_token_id
+ position_id_per_seconds = config.position_id_per_seconds
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None
+ or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.zeros(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_idx, video_idx, audio_idx = 0, 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums, audio_nums = 0, 0, 0
+ vision_start_indices = torch.argwhere(
+ input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ audio_nums = torch.sum(input_ids == audio_start_token_id)
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = ((vision_tokens == audio_start_token_id).sum()
+ if use_audio_in_video else
+ (vision_tokens == video_token_id).sum())
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos, remain_audios = (image_nums,
+ video_nums,
+ audio_nums)
+ multimodal_nums = (image_nums + audio_nums
+ if use_audio_in_video else image_nums +
+ video_nums + audio_nums)
+ for _ in range(multimodal_nums):
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ if (image_token_id in input_tokens or video_token_id
+ in input_tokens) and (remain_videos > 0
+ or remain_images > 0):
+ ed_vision_start = input_tokens.index(
+ vision_start_token_id, st)
+ else:
+ ed_vision_start = len(input_tokens) + 1
+ if audio_token_id in input_tokens and remain_audios > 0:
+ ed_audio_start = input_tokens.index(
+ audio_start_token_id, st)
+ else:
+ ed_audio_start = len(input_tokens) + 1
+ min_ed = min(ed_vision_start, ed_audio_start)
+ if min_ed == ed_audio_start:
+ text_len = min_ed - st
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].long().max(
+ ) + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(
+ 3, -1) + st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(bos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ audio_len = _get_feat_extract_output_lengths(
+ audio_seqlens[audio_idx])
+ llm_pos_ids = torch.arange(audio_len).view(
+ 1, -1).expand(3, -1) + st_idx
+ llm_pos_ids_list.append(llm_pos_ids)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(eos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st += text_len + bos_len + audio_len + eos_len
+ audio_idx += 1
+ remain_audios -= 1
+ elif min_ed == ed_vision_start and input_ids[
+ ed_vision_start + 1] == image_token_id:
+ text_len = min_ed - st
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].long().max(
+ ) + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(
+ 3, -1) + st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(bos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ grid_t = image_grid_thw[image_idx][0]
+ grid_hs = image_grid_thw[:, 1]
+ grid_ws = image_grid_thw[:, 2]
+ t_index = ((torch.arange(grid_t)) * 1 *
+ position_id_per_seconds)
+ llm_pos_ids = cls._get_llm_pos_ids_for_vision(
+ st_idx, image_idx, spatial_merge_size, t_index,
+ grid_hs, grid_ws)
+ image_len = image_grid_thw[image_idx].prod() // (
+ spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(eos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st += text_len + bos_len + image_len + eos_len
+ image_idx += 1
+ remain_images -= 1
+ elif min_ed == ed_vision_start and input_ids[
+ ed_vision_start +
+ 1] == video_token_id and not use_audio_in_video:
+ text_len = min_ed - st
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].long().max(
+ ) + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(
+ 3, -1) + st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(bos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+ t_index = ((torch.arange(grid_t)) *
+ second_per_grids[video_idx].cpu() *
+ position_id_per_seconds)
+ llm_pos_ids = cls._get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index,
+ grid_hs, grid_ws)
+ video_len = video_grid_thw[video_idx].prod() // (
+ spatial_merge_size**2)
+ llm_pos_ids_list.append(llm_pos_ids)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(eos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st += text_len + bos_len + video_len + eos_len
+ video_idx += 1
+ remain_videos -= 1
+ elif (min_ed == ed_vision_start
+ and ed_vision_start + 1 == ed_audio_start
+ and use_audio_in_video):
+ text_len = min_ed - st
+ if text_len != 0:
+ st_idx = llm_pos_ids_list[-1].long().max(
+ ) + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(
+ 3, -1) + st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ bos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(bos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ llm_pos_ids_list.append(
+ torch.arange(bos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ audio_len = _get_feat_extract_output_lengths(
+ audio_seqlens[audio_idx])
+ audio_llm_pos_ids = torch.arange(audio_len).view(
+ 1, -1).expand(3, -1) + st_idx
+ grid_t = video_grid_thw[video_idx][0]
+ grid_hs = video_grid_thw[:, 1]
+ grid_ws = video_grid_thw[:, 2]
+ t_index = ((torch.arange(grid_t)) *
+ second_per_grids[video_idx].cpu() *
+ position_id_per_seconds)
+ video_llm_pos_ids = cls._get_llm_pos_ids_for_vision(
+ st_idx, video_idx, spatial_merge_size, t_index,
+ grid_hs, grid_ws)
+ video_data_index, audio_data_index = 0, 0
+ while (video_data_index < video_llm_pos_ids.shape[-1]
+ and audio_data_index
+ < audio_llm_pos_ids.shape[-1]):
+ if video_llm_pos_ids[0][
+ video_data_index] <= audio_llm_pos_ids[0][
+ audio_data_index]:
+ llm_pos_ids_list.append(
+ video_llm_pos_ids[:, video_data_index:
+ video_data_index + 1])
+ video_data_index += 1
+ else:
+ llm_pos_ids_list.append(
+ audio_llm_pos_ids[:, audio_data_index:
+ audio_data_index + 1])
+ audio_data_index += 1
+ if video_data_index < video_llm_pos_ids.shape[-1]:
+ llm_pos_ids_list.append(
+ video_llm_pos_ids[:, video_data_index:
+ video_llm_pos_ids.shape[-1]])
+ if audio_data_index < audio_llm_pos_ids.shape[-1]:
+ llm_pos_ids_list.append(
+ audio_llm_pos_ids[:, audio_data_index:
+ audio_llm_pos_ids.shape[-1]])
+ video_len = video_grid_thw[video_idx].prod() // (
+ spatial_merge_size**2)
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ eos_len = 1
+ llm_pos_ids_list.append(
+ torch.arange(eos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ llm_pos_ids_list.append(
+ torch.arange(eos_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ st += (text_len + bos_len * 2 + audio_len + video_len +
+ eos_len * 2)
+ audio_idx += 1
+ video_idx += 1
+ remain_videos -= 1
+ remain_audios -= 1
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].long().max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ llm_positions = torch.cat(llm_pos_ids_list,
+ dim=1).reshape(3, -1)
+
+ position_ids[..., i,
+ attention_mask[i] == 1] = llm_positions.to(
+ position_ids.device)
+ mrope_position_deltas.append(llm_positions.long().max() + 1 -
+ len(input_ids))
+ mrope_position_deltas = torch.tensor(
+ mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas.long()
+ else:
+ position_ids = attention_mask.cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(
+ attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+ -1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - torch.sum(
+ attention_mask, dim=-1, keepdim=True)
+ return position_ids, mrope_position_deltas.long()
+
@classmethod
def _omni_get_input_positions_tensor(
cls,
@@ -1563,8 +1969,25 @@ def _omni_get_input_positions_tensor(
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
-
+        model_type = hf_config.model_type
thinker_config = hf_config.thinker_config
+
+        if isinstance(image_grid_thw, list):
+            image_grid_thw = torch.tensor(image_grid_thw)
+        if isinstance(video_grid_thw, list):
+            video_grid_thw = torch.tensor(video_grid_thw)
+
+        if "qwen3_omni" in model_type:
+            (llm_positions,
+             mrope_position_delta) = cls._omni3_get_input_positions_tensor(
+                 thinker_config, torch.tensor([input_tokens]), image_grid_thw,
+                 video_grid_thw, None, use_audio_in_video,
+                 audio_feature_lengths,
+                 torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
+            llm_positions = llm_positions.squeeze(1)
+            mrope_position_delta = mrope_position_delta.squeeze()
+            return llm_positions, mrope_position_delta
+
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
@@ -1577,11 +2000,6 @@ def _omni_get_input_positions_tensor(
tokens_per_second = getattr(thinker_config.vision_config,
"tokens_per_second", 25)
- if isinstance(image_grid_thw, list):
- image_grid_thw = torch.tensor(image_grid_thw)
- if isinstance(video_grid_thw, list):
- video_grid_thw = torch.tensor(video_grid_thw)
-
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 8e6c76b5a564..b54f187f8e7f 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -299,7 +299,7 @@ def __init__(self,
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer):
super().__init__()
- config = vllm_config.model_config.hf_config
+ config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index d8d324d7cd16..af2de72e68ab 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -87,7 +87,7 @@
from habana_frameworks.torch.hpex.kernels import FusedSDPA
# For profile run
-_MAX_FRAMES_PER_VIDEO = 16
+_MAX_FRAMES_PER_VIDEO = 600
# === Vision Inputs === #
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index ca509df0468c..881c9d1b418c 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -321,7 +321,7 @@ class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
- config = vllm_config.model_config.hf_config
+ config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
new file mode 100644
index 000000000000..1ae38c81b4cb
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -0,0 +1,1545 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2025 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
+
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig
+from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import (
+ Qwen3OmniMoeConfig, Qwen3OmniMoeThinkerConfig)
+from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
+ Qwen3OmniMoeAudioEncoder)
+from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import (
+ Qwen3OmniMoeProcessor)
+from transformers.models.whisper import WhisperFeatureExtractor
+
+from vllm.attention.layer import check_upstream_fa_availability
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.qwen2_audio import (Qwen2AudioInputs,
+ Qwen2AudioProcessingInfo)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalKwargsItems, NestedTensors
+from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ PlaceholderFeaturesInfo,
+ PromptReplacement, PromptUpdate)
+from vllm.platforms import _Backend, current_platform
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import decode_tokens
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .qwen2_5_omni_thinker import (Qwen2_5OmniConditionalGenerationMixin,
+ Qwen2_5OmniThinkerDummyInputsBuilder,
+ Qwen2_5OmniThinkerMultiModalProcessor,
+ Qwen2_5OmniThinkerProcessingInfo)
+from .qwen2_5_vl import (Qwen2_5_VisionAttention,
+ Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VLImageInputs,
+ Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoInputs)
+from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
+ merge_multimodal_embeddings)
+from .vision import get_vit_attn_backend
+
+try:
+ import flash_attn
+except (ImportError, ModuleNotFoundError):
+ flash_attn = None
+
+logger = init_logger(__name__)
+is_hpu = current_platform.is_hpu()
+
+if is_hpu:
+ import habana_frameworks.torch as htorch
+ import habana_frameworks.torch.core as htcore
+
+
+class Qwen3_VisionPatchEmbed(nn.Module):
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ hidden_size: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.hidden_size = hidden_size
+
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
+ self.proj = nn.Conv3d(in_channels,
+ hidden_size,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
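+        # x: (num_patches, in_channels * temporal_patch_size * patch_size**2)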
+ L, C = x.shape
+ x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
+ self.patch_size)
+ x = self.proj(x).view(L, self.hidden_size)
+ return x
+
+
+class Qwen3_VisionMLP(nn.Module):
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = False,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.linear_fc1 = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.linear_fc1")
+ self.linear_fc2 = RowParallelLinear(hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.linear_fc2")
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return mlp_output
+
+
+class Qwen3_VisionBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+ self.mlp = Qwen3_VisionMLP(dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: Optional[int] = None, # Only used for Flash Attention
+ seqlens: Optional[list[int]] = None, # Only used for xFormers
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ x = x + self.attn(self.norm1(x),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens,
+ attn_mask=attn_mask)
+
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class Qwen3_VisionPatchMerger(nn.Module):
+
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ spatial_merge_size: int = 2,
+ use_postshuffle_norm: bool = False,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+
+ self.use_postshuffle_norm = use_postshuffle_norm
+ if self.use_postshuffle_norm:
+ context_dim = self.hidden_size
+
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.ln_q = norm_layer(
+ self.hidden_size if use_postshuffle_norm else context_dim)
+ self.mlp = nn.ModuleList([
+ ColumnParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.0"),
+ nn.GELU(),
+ RowParallelLinear(self.hidden_size,
+ d_model,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.2"),
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.use_postshuffle_norm:
+ x = self.ln_q(x.view(-1, self.hidden_size))
+ else:
+ x = self.ln_q(x).view(-1, self.hidden_size)
+
+ mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+ x_parallel, _ = mlp_fc1(x)
+ x_parallel = mlp_act(x_parallel)
+ out, _ = mlp_fc2(x_parallel)
+ return out
+
+
+class Qwen3Omni_VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ vision_config,
+ norm_eps: float = 1e-6,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_heads
+ self.image_size = vision_config.image_size
+ self.patch_size = vision_config.patch_size
+ self.spatial_merge_size = vision_config.spatial_merge_size
+ self.spatial_merge_unit = self.spatial_merge_size**2
+ self.temporal_patch_size = vision_config.temporal_patch_size
+ self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed
+ self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+
+ self.patch_embed = Qwen3_VisionPatchEmbed(
+ patch_size=self.patch_size,
+ temporal_patch_size=self.temporal_patch_size,
+ in_channels=vision_config.in_channels,
+ hidden_size=self.hidden_size,
+ )
+
+ # vit pos embedding, TODO: spatial_patch_size vs patch_size
+ if self.apply_vit_abs_pos_embed:
+ self.pos_embed = nn.Embedding(
+ (self.image_size // self.patch_size)**2, self.hidden_size)
+ else:
+ self.pos_embed = nn.Parameter(
+ torch.empty([
+ 1, (self.image_size // self.patch_size)**2,
+ self.hidden_size
+ ]))
+
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([
+ Qwen3_VisionBlock(
+ dim=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.blocks.{layer_idx}")
+ for layer_idx in range(vision_config.depth)
+ ])
+ self.merger = Qwen3_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ norm_layer=norm_layer,
+ spatial_merge_size=self.spatial_merge_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.merger",
+ )
+ if self.deepstack_visual_indexes is not None:
+ self.merger_list = nn.ModuleList([
+ Qwen3_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=True,
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.merger_list.{layer_idx}")
+ for layer_idx in range(len(self.deepstack_visual_indexes))
+ ])
+
+ self.attn_backend = get_vit_attn_backend(support_fa=True)
+ if self.attn_backend != _Backend.FLASH_ATTN and \
+ check_upstream_fa_availability(
+ torch.get_default_dtype()):
+ self.attn_backend = _Backend.FLASH_ATTN
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.patch_embed.proj.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.patch_embed.proj.weight.device
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def fast_pos_embed_interpolate(self, grid_thw):
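+        # Bilinearly interpolate the learned absolute position embeddings from
+        # the native (num_grid_per_side, num_grid_per_side) grid onto each
+        # image's (h, w) patch grid, then reorder to match spatial merging.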
+ num_grid_per_side = self.image_size // self.patch_size
+
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+        # TODO: use torch instead of np
+ for t, h, w in grid_thw:
+ h_idxs = np.linspace(0, num_grid_per_side - 1, h)
+ w_idxs = np.linspace(0, num_grid_per_side - 1, w)
+
+ h_idxs_floor = h_idxs.astype(int)
+ w_idxs_floor = w_idxs.astype(int)
+ h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side -
+ 1)
+ w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side -
+ 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
+ w_idxs_floor[None]).flatten().tolist() * t)
+ idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
+ w_idxs_ceil[None]).flatten().tolist() * t)
+ idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
+ w_idxs_floor[None]).flatten().tolist() * t)
+ idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
+ w_idxs_ceil[None]).flatten().tolist() * t)
+
+ weight_list[0].extend(
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
+ weight_list[1].extend(
+ ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
+ weight_list[2].extend(
+ (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
+ weight_list[3].extend(
+ (dh[None].T * dw[None]).flatten().tolist() * t)
+
+ device = self.pos_embed.weight.device
+ dtype = self.pos_embed.weight.dtype
+
+ p0 = self.pos_embed(
+ torch.tensor(
+ idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[0], dtype=dtype, device=device)[:, None]
+ p1 = self.pos_embed(
+ torch.tensor(
+ idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[1], dtype=dtype, device=device)[:, None]
+ p2 = self.pos_embed(
+ torch.tensor(
+ idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[2], dtype=dtype, device=device)[:, None]
+ p3 = self.pos_embed(
+ torch.tensor(
+ idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[3], dtype=dtype, device=device)[:, None]
+
+ patch_pos_embeds = p0 + p1 + p2 + p3
+ patch_pos_embeds = patch_pos_embeds.split(
+ [t * h * w for t, h, w in grid_thw])
+ patch_pos_embeds_permute = []
+ m_size = self.spatial_merge_size
+ for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
+ pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
+ m_size, -1).permute(0, 1, 3, 2, 4,
+ 5).flatten(0, 4)
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def compute_attn_mask_seqlen(
+ self,
+ cu_seqlens: torch.Tensor,
+ ) -> tuple[Optional[int], Optional[list[int]]]:
+ max_seqlen, seqlens = None, None
+ if self.attn_backend == _Backend.FLASH_ATTN:
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ elif self.attn_backend == _Backend.XFORMERS:
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ return max_seqlen, seqlens
+
+ def forward(
+ self,
+ x: torch.Tensor,
+        grid_thw: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ if self.apply_vit_abs_pos_embed:
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=grid_thw.dtype
+ if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ hidden_states = hidden_states.unsqueeze(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+ max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+
+ hidden_states_list = []
+ deepstack_visual_indexes = self.deepstack_visual_indexes
+
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens)
+ if (deepstack_visual_indexes is not None
+ and layer_num in deepstack_visual_indexes):
+ hidden_states_list.append(hidden_states)
+
+ hidden_states = self.merger(hidden_states)
+
+ # processing deepstack
+ if deepstack_visual_indexes is not None:
+ processed_hidden_states_list = [hidden_states]
+ for idx, x in enumerate(hidden_states_list):
+ x = self.merger_list[idx](x)
+ processed_hidden_states_list.append(x)
+ # we cat the original visual features and deepstack
+ # features along the feature dim
+ hidden_states = torch.cat(
+ processed_hidden_states_list,
+ dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("attn.qkv.", "attn.q.", "q"),
+ ("attn.qkv.", "attn.k.", "k"),
+ ("attn.qkv.", "attn.v.", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Qwen3Omni_VisionTransformerStaticShape(Qwen3Omni_VisionTransformer):
+ """
+ Here we overwrite some of the methods of Qwen3Omni_VisionTransformer
+ to make the model more friendly to static shapes. Specifically,
+ we split the forward method into:
+ - pre_attn (dynamic)
+ - forward (static shape)
+ and we should call get_image_embeds instead of forward, allowing
+ the forward method ro run with HPU_Graphs, whereas the
+ pre_attn and post_attn methods are allow to be dynamic.
+ """
+
+ def pad_multimodal_data(self,
+ pixel_values,
+ vision_buckets,
+ constant_value=0):
+
+ desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
+ pixel_values.shape[0])
+ padding_len = desired_number_of_pixels - pixel_values.shape[0]
+ if padding_len <= 0:
+ return pixel_values
+
+ logger_msg = "Padding current number pixel " \
+ + str(pixel_values.shape[0]) \
+ + " to " \
+ + str(desired_number_of_pixels)
+ logger.debug(logger_msg)
+
+ pixel_values = F.pad(pixel_values, (0, 0, 0, padding_len), "constant",
+ constant_value)
+
+ return pixel_values
+
+ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor,
+ vision_buckets):
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ attention_mask = torch.ones(hidden_states.size(0),
+ 1).to(device=self.device)
+
+ hidden_states = self.pad_multimodal_data(hidden_states, vision_buckets,
+ 0)
+ rotary_pos_emb = self.pad_multimodal_data(rotary_pos_emb,
+ vision_buckets, -100)
+ attention_mask = self.pad_multimodal_data(attention_mask,
+ vision_buckets, 0)
+
+ return hidden_states, rotary_pos_emb, attention_mask
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ rotary_pos_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = x.unsqueeze(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+ hidden_states_list = []
+ deepstack_visual_indexes = self.deepstack_visual_indexes
+
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ attn_mask=attn_mask,
+ cu_seqlens=None,
+ )
+ if (deepstack_visual_indexes is not None
+ and layer_num in deepstack_visual_indexes):
+ hidden_states_list.append(hidden_states)
+
+ hidden_states = self.merger(hidden_states)
+
+ # processing deepstack
+ if deepstack_visual_indexes is not None:
+ processed_hidden_states_list = [hidden_states]
+ for idx, x in enumerate(hidden_states_list):
+ x = self.merger_list[idx](x)
+ processed_hidden_states_list.append(x)
+ # we cat the original visual features
+ # and deepstack features along the feature dim
+ hidden_states = torch.cat(
+ processed_hidden_states_list,
+ dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+
+ return hidden_states
+
+ def get_image_embeds(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ vision_buckets,
+ ) -> torch.Tensor:
+
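+        # Process one image / video frame at a time, padding each to a vision
+        # bucket so that the attention blocks run with static shapes under
+        # HPU graphs.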
+ offset = 0
+ results = []
+ # process each image one by one
+ for img_idx in range(grid_thw.shape[0]):
+ img_shape = grid_thw[img_idx, :].unsqueeze(0).clone()
+ # For video, we process frames separately
+ grid_t = grid_thw[img_idx, 0]
+ img_shape[0, 0] = 1
+ curr_img_size = img_shape.prod()
+ for _ in torch.arange(0, grid_t):
+ pixel_values_curr_img = pixel_values[offset:offset +
+ curr_img_size, :]
+
+ offset += curr_img_size
+
+ (pixel_values_curr_img_padded, rot_pos_emb,
+ attention_mask) = self.pre_attn(pixel_values_curr_img,
+ img_shape, vision_buckets)
+
+ fullatt_block_attn_mask = \
+ attention_mask.squeeze(1).unsqueeze(0) * attention_mask
+
+ extra_forward_kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ padded_len = pixel_values_curr_img_padded.shape[0]
+ use_graph = vision_buckets.use_graph(padded_len)
+ extra_forward_kwargs.update(
+ {"bypass_hpu_graphs": not use_graph})
+
+ htcore.mark_step()
+ hidden_states = self.forward(pixel_values_curr_img_padded,
+ rotary_pos_emb=rot_pos_emb,
+ attn_mask=fullatt_block_attn_mask,
+ **extra_forward_kwargs)
+ htcore.mark_step()
+
+ post_embed_size = curr_img_size // self.spatial_merge_unit
+ results += [hidden_states[:post_embed_size, :]]
+
+ results_cat = torch.concat(results)
+ image_embeds = results_cat
+ return image_embeds
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ "deepstack_input_embeds": 0
+ })
+class Qwen3MoeLLMModel(Qwen3MoeModel):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+ self.deepstack_multiscale_layer_start = 1
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ deepstack_input_embeds: Optional[IntermediateTensors] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ for layer_idx, layer in enumerate(
+ self.layers[self.start_layer:self.end_layer]):
+ layer_idx = layer_idx + self.start_layer
+
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ )
+
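+            # Deepstack: add the multiscale visual embeddings to the hidden
+            # states of the first len(deepstack_input_embeds) decoder layers.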
+ if deepstack_input_embeds is not None and \
+ layer_idx in range(0, len(deepstack_input_embeds)):
+ hidden_states = hidden_states + deepstack_input_embeds[
+ f"deepstack_input_embeds_{layer_idx}"]
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super(Qwen3MoeForCausalLM, self).__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+
+class Qwen3OmniMoeThinkerProcessingInfo(Qwen2AudioProcessingInfo,
+ Qwen2_5_VLProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config
+
+ def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor:
+ processor = self.ctx.get_hf_processor(
+ Qwen3OmniMoeProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+ if not hasattr(processor, "audio_token"):
+ processor.audio_token = "<|audio_pad|>"
+ if not hasattr(processor, "image_token"):
+ processor.image_token = "<|image_pad|>"
+ if not hasattr(processor, "video_token"):
+ processor.video_token = "<|video_pad|>"
+ return processor
+
+ def get_feature_extractor(self, **kwargs: object):
+ hf_processor = self.get_hf_processor(**kwargs)
+ feature_extractor = hf_processor.feature_extractor # type: ignore
+ assert isinstance(feature_extractor, WhisperFeatureExtractor)
+ return feature_extractor
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"audio": None, "image": None, "video": None}
+
+
+class Qwen3OmniMoeThinkerMultiModalProcessor(
+        Qwen2_5OmniThinkerMultiModalProcessor):
+
+    def _get_feat_extract_output_lengths(
+            self, input_lengths: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
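+        # Each full 100-frame mel chunk maps to 13 output embeddings and the
+        # remainder goes through the conv downsampling formula below, e.g.
+        # 300 mel frames -> 0 + 3 * 13 = 39 tokens,
+        # 250 mel frames -> 7 + 2 * 13 = 33 tokens.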
+ input_lengths_leave = input_lengths % 100
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
+ output_lengths = ((feat_lengths - 1) // 2 + 1 -
+ 1) // 2 + 1 + (input_lengths // 100) * 13
+ return feat_lengths, output_lengths
+
+ def _maybe_apply_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ prompt_ids: list[int],
+ mm_kwargs: MultiModalKwargsItems,
+ is_update_applied: bool,
+ ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
+ """
+ Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
+ """
+ unbound_prompt_updates = self._get_prompt_updates(
+ mm_items,
+ hf_processor_mm_kwargs,
+ mm_kwargs,
+ )
+ mm_prompt_updates = self._bind_and_group_updates(
+ unbound_prompt_updates)
+ mm_item_counts = mm_items.get_all_counts()
+ self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
+
+ use_audio_in_video = (all(
+ item["use_audio_in_video"].data
+ for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)
+
+ if use_audio_in_video and "video" in mm_item_counts:
+ assert "audio" in mm_item_counts
+ mm_item_counts["audio"] -= mm_item_counts["video"]
+
+ if is_update_applied:
+ prompt_ids = self._get_raw_input_ids(prompt_ids,
+ use_audio_in_video)
+
+ (
+ prompt_ids,
+ prompt,
+ mm_placeholders,
+ ) = self._apply_prompt_updates(
+ prompt_ids,
+ mm_prompt_updates,
+ mm_item_counts,
+ )
+ self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
+
+ tokenizer = self.info.get_tokenizer()
+ prompt = decode_tokens(tokenizer, prompt_ids)
+
+ return prompt_ids, prompt, mm_placeholders
+
+ def get_updates_use_audio_in_video(
+ self,
+ thinker_config: PretrainedConfig,
+ audio_len: int,
+ video_grid_thw: Union[list[int], torch.Tensor],
+ video_second_per_grid_t: float,
+ ) -> list[int]:
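+        # Interleave audio and video placeholder tokens in temporal order:
+        # audio positions advance one per token, while video positions
+        # advance by position_id_per_seconds * video_second_per_grid_t per
+        # frame. A two-pointer merge emits whichever token comes first, and
+        # the whole sequence is wrapped in audio start/end tokens.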
+ shift = 0
+ audio_token_id = thinker_config.audio_token_id
+ video_token_id = thinker_config.video_token_id
+ audio_start_token_id = thinker_config.audio_start_token_id
+ audio_end_token_id = thinker_config.audio_end_token_id
+ spatial_merge_size = thinker_config.vision_config.spatial_merge_size
+ position_id_per_seconds = thinker_config.position_id_per_seconds
+        audio_token_indices = np.arange(audio_len)
+        curr_video_grid_thw = video_grid_thw
+ height = curr_video_grid_thw[1] // spatial_merge_size
+ width = curr_video_grid_thw[2] // spatial_merge_size
+ video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(
+ -1, 1, 1)
+ video_token_indices = np.broadcast_to(
+ video_token_indices,
+ (video_token_indices.shape[0], height, width)).reshape(-1)
+        video_token_indices = ((video_token_indices + shift) *
+                               video_second_per_grid_t *
+                               position_id_per_seconds)
+ video_data_index, audio_data_index = 0, 0
+ updates = [audio_start_token_id]
+ while video_data_index < len(
+ video_token_indices) and audio_data_index < len(
+ audio_token_indices):
+ if video_token_indices[video_data_index] <= audio_token_indices[
+ audio_data_index]:
+ updates += [video_token_id]
+ video_data_index += 1
+ else:
+ updates += [audio_token_id]
+ audio_data_index += 1
+ if video_data_index < len(video_token_indices):
+ updates += [video_token_id
+ ] * (len(video_token_indices) - video_data_index)
+ if audio_data_index < len(audio_token_indices):
+ updates += [audio_token_id
+ ] * (len(audio_token_indices) - audio_data_index)
+ updates += [audio_end_token_id]
+ return updates
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ image_processor = self.info.get_image_processor(
+ **hf_processor_mm_kwargs)
+ vocab = tokenizer.get_vocab()
+
+ audio_token = processor.audio_token
+ image_token = processor.image_token
+ video_token = processor.video_token
+ audio_token_id = vocab[audio_token]
+ image_token_id = vocab[image_token]
+ video_token_id = vocab[video_token]
+
+ audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths")
+ feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
+ if audio_feature_lengths is None and feature_attention_mask is None:
+ audio_output_lengths = []
+ elif audio_feature_lengths is not None:
+ _, audio_output_lens = self._get_feat_extract_output_lengths(
+ audio_feature_lengths)
+ audio_output_lengths = audio_output_lens.tolist()
+ elif feature_attention_mask is not None:
+ assert isinstance(feature_attention_mask, torch.Tensor)
+ _, audio_output_lens = self._get_feat_extract_output_lengths(
+ feature_attention_mask.sum(-1))
+ audio_output_lengths = audio_output_lens.tolist()
+
+        # Number of audio items that come from videos (use_audio_in_video).
+ audio_in_video_item_idx = 0
+ audio_item_idx = 0
+
+ def get_replacement_qwen2_audio(item_idx: int):
+ nonlocal audio_item_idx
+ item_idx += audio_in_video_item_idx
+
+ audio_item_idx += 1
+
+ num_features = audio_output_lengths[item_idx]
+ if num_features == 0:
+ audios = mm_items.get_items("audio", AudioProcessorItems)
+ audio = audios.get(item_idx)
+ raise ValueError(
+ f"The audio {audio} (len={len(audio)}) is too short "
+ "to be represented inside the model")
+
+ return [audio_token_id] * num_features
+
+ def get_replacement_qwen2_vision(item_idx: int, modality: str):
+ grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
+ assert isinstance(grid_thw, torch.Tensor)
+ merge_length = image_processor.merge_size**2
+
+ token_id = image_token_id if modality == "image" else video_token_id
+ return [token_id] * (int(grid_thw.prod()) // merge_length)
+
+ use_audio_in_video = hf_processor_mm_kwargs.get(
+ "use_audio_in_video", False)
+ thinker_config = self.info.get_hf_config()
+
+ def get_replacement_qwen2_use_audio_in_video(item_idx: int):
+ nonlocal audio_in_video_item_idx
+ audio_num_features = audio_output_lengths[audio_item_idx +
+ item_idx]
+ video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx]
+
+ audio_in_video_item_idx += 1
+
+ second_per_grid_ts = hf_processor_mm_kwargs.get(
+ "second_per_grid_ts", None)
+ if second_per_grid_ts:
+ video_second_per_grid_t = second_per_grid_ts[item_idx]
+ else:
+ video_second_per_grid_t = 1.0
+
+ return self.get_updates_use_audio_in_video(
+ thinker_config=thinker_config,
+ audio_len=audio_num_features,
+ video_grid_thw=video_grid_thw,
+ video_second_per_grid_t=video_second_per_grid_t,
+ )
+
+ video_replacement_fn = (
+ get_replacement_qwen2_use_audio_in_video if use_audio_in_video else
+ partial(get_replacement_qwen2_vision, modality="video"))
+
+ return [
+ PromptReplacement(
+ modality="audio",
+ target=audio_token,
+ replacement=get_replacement_qwen2_audio,
+ ),
+ PromptReplacement(
+ modality="image",
+ target=image_token,
+ replacement=partial(get_replacement_qwen2_vision,
+ modality="image"),
+ ),
+ PromptReplacement(
+ modality="video",
+ target=video_token,
+ replacement=video_replacement_fn,
+ ),
+ ]
+
+ def _validate_mm_placeholders(
+ self,
+ mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+ mm_item_counts: Mapping[str, int],
+ ) -> None:
+ BaseMultiModalProcessor[
+ Qwen2_5OmniThinkerProcessingInfo]._validate_mm_placeholders(
+ self, mm_placeholders, mm_item_counts)
+
+ def _get_raw_input_ids(
+ self,
+ token_ids: list[int],
+ use_audio_in_video: bool = False,
+ ) -> list[int]:
+
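+        # Reverse previously applied prompt updates: collapse each
+        # <vision_bos><audio_bos>...<audio_eos><vision_eos> span produced by
+        # use_audio_in_video back into a single video placeholder, then
+        # squeeze consecutive runs of each pad token down to a single token.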
+ tokenizer = self.info.get_tokenizer()
+ vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0]
+ vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0]
+ audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0]
+ audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0]
+ audio_token = tokenizer.encode("<|audio_pad|>")[0]
+ image_token = tokenizer.encode("<|image_pad|>")[0]
+ video_token = tokenizer.encode("<|video_pad|>")[0]
+
+ result = token_ids[:]
+ if use_audio_in_video:
+ while True:
+ start = None
+ for i in range(len(result) - 1):
+ if result[i:i + 2] == [vision_bos_token, audio_bos_token]:
+ start = i
+ break
+ if start is not None:
+ end = None
+ for i in range(start + 2, len(result) - 1):
+ if result[i:i +
+ 2] == [audio_eos_token, vision_eos_token]:
+ end = i
+ break
+ if end is not None:
+ result = result[:start] + [
+ vision_bos_token, video_token, vision_eos_token
+ ] + result[end + 2:]
+ else:
+ break
+
+ for mm_token in [audio_token, image_token, video_token]:
+ compressed = []
+ for x in result:
+ if x != mm_token or (not compressed
+ or compressed[-1] != mm_token):
+ compressed.append(x)
+ result = compressed
+
+ return result
+
+
+class Qwen3OmniMoeConditionalGenerationMixin(
+ Qwen2_5OmniConditionalGenerationMixin):
+
+ def _validate_and_reshape_mm_tensor(self,
+ mm_input: object,
+ name: str,
+ dim: int = 0) -> torch.Tensor:
+ if not isinstance(mm_input, (torch.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if name == "feature_attention_mask":
+ dim = -1
+ if isinstance(mm_input, torch.Tensor):
+ return torch.concat(list(mm_input), dim=dim)
+ else:
+ if isinstance(mm_input[0], list):
+ return torch.concat([
+ torch.concat(mm_input[i], dim=dim)
+ for i in range(len(mm_input))
+ ],
+ dim=dim)
+ else:
+ return torch.concat(mm_input, dim=dim)
+
+    def _get_feat_extract_output_lengths(
+            self, input_lengths: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
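+        # Same chunked formula as the processor variant, but both returned
+        # values are the final output lengths (the caller passes the first
+        # one to the audio tower as aftercnn_lens).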
+ input_lengths_leave = input_lengths % 100
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
+ output_lengths = ((feat_lengths - 1) // 2 + 1 -
+ 1) // 2 + 1 + (input_lengths // 100) * 13
+ return output_lengths, output_lengths
+
+ def _process_audio_input(
+ self,
+ audio_input: Qwen2AudioInputs,
+ audio_hashes: list[str] = None,
+ cached_audio_features: torch.Tensor = None,
+ ) -> torch.Tensor:
+
+ input_features = audio_input["input_features"]
+ audio_feature_lengths = audio_input["audio_feature_lengths"]
+
+ if input_features.ndim == 3:
+ assert input_features.shape[0] == 1
+ input_features = input_features.squeeze(0)
+
+ if not isinstance(audio_feature_lengths, torch.Tensor):
+ audio_feature_lengths = torch.cat(audio_feature_lengths)
+ if audio_feature_lengths.ndim == 2:
+ audio_feature_lengths = audio_feature_lengths.reshape(-1)
+
+ audio_feat_lengths, audio_output_lengths = (
+ self._get_feat_extract_output_lengths(audio_feature_lengths))
+
+ audio_outputs = self.audio_tower(
+ input_features.to(self.audio_tower.dtype),
+ feature_lens=audio_feature_lengths,
+ aftercnn_lens=audio_feat_lengths,
+ )
+ audio_features = audio_outputs.last_hidden_state
+ return audio_features.split(audio_output_lengths.tolist())
+
+ def _process_image_input(
+ self,
+ image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+ if image_input["type"] == "image_embeds":
+ return image_input["image_embeds"].type(self.visual.dtype)
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+ if is_hpu:
+ assert isinstance(self.visual,
+ Qwen3Omni_VisionTransformerStaticShape)
+ image_embeds = self.visual.get_image_embeds(
+ pixel_values,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+ # Split concatenated embeddings for each image item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self,
+ video_input: Qwen2_5_VLVideoInputs,
+ video_hashes: list[str] = None,
+ cached_video_embeds: torch.Tensor = None) -> torch.Tensor:
+ if video_input["type"] == "video_embeds":
+ return video_input["video_embeds"].type(self.visual.dtype)
+
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ if is_hpu:
+ assert isinstance(self.visual,
+ Qwen3Omni_VisionTransformerStaticShape)
+ video_embeds = self.visual.get_image_embeds(
+ pixel_values_videos,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+ video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+ # Split concatenated embeddings for each video item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return video_embeds.split(sizes.tolist())
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Qwen3OmniMoeThinkerMultiModalProcessor,
+ info=Qwen3OmniMoeThinkerProcessingInfo,
+ dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
+)
+class Qwen3OmniMoeThinkerForConditionalGeneration(
+ nn.Module,
+ SupportsMultiModal,
+ SupportsPP,
+ Qwen3OmniMoeConditionalGenerationMixin,
+):
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "thinker.lm_head.": "language_model.lm_head.",
+ "thinker.model.": "language_model.model.",
+ "thinker.": "",
+ })
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<|vision_start|><|image_pad|><|vision_end|>"
+ if modality.startswith("video"):
+ return "<|vision_start|><|video_pad|><|vision_end|>"
+ if modality.startswith("audio"):
+ return "<|audio_start|><|audio_pad|><|audio_end|>"
+
+ raise ValueError("Only image, video or audio modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ thinker_config: Qwen3OmniMoeThinkerConfig = (
+ vllm_config.model_config.hf_config.thinker_config)
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+ self.config = thinker_config
+ self.config.architectures = [
+ "Qwen3OmniMoeThinkerForConditionalGeneration"
+ ]
+ self.text_dim = thinker_config.text_config.hidden_size
+ self.multimodal_config = multimodal_config
+
+        # Force the audio tower to use flash_attention_2 so that its outputs
+        # align with the reference transformers implementation.
+ if flash_attn is not None:
+ audio_config = thinker_config.audio_config
+ audio_config._attn_implementation_autoset = True
+ audio_config._attn_implementation = "flash_attention_2"
+ else:
+            logger.warning(
+                "flash_attn is not available; the audio tower may not yield "
+                "exactly the same results as the transformers "
+                "implementation.")
+
+ self.audio_tower = Qwen3OmniMoeAudioEncoder(
+ thinker_config.audio_config)
+ if is_hpu:
+ qwen3omni_visionTransformer = Qwen3Omni_VisionTransformerStaticShape
+ else:
+ qwen3omni_visionTransformer = Qwen3Omni_VisionTransformer
+ self.visual = qwen3omni_visionTransformer(
+ vision_config=thinker_config.vision_config,
+ norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+ self.quant_config = quant_config
+
+ self.language_model = Qwen3MoeLLMForCausalLM(
+ vllm_config=vllm_config.with_hf_config(
+ thinker_config.text_config,
+ architectures=["Qwen3MoeForCausalLM"]),
+ prefix=maybe_prefix(prefix, "language_model"))
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ self.use_deepstack = hasattr(thinker_config.vision_config,
+ 'deepstack_visual_indexes')
+ self.deepstack_num_level = len(
+ thinker_config.vision_config.deepstack_visual_indexes
+ ) if self.use_deepstack else 0
+ # register buffer for deepstack
+ self.deepstack_input_embeds = [
+ torch.zeros(1, vllm_config.scheduler_config.max_num_batched_tokens,
+ thinker_config.text_config.hidden_size)
+ for _ in range(self.deepstack_num_level)
+ ] if self.use_deepstack else None
+ self.visual_dim = thinker_config.vision_config.out_hidden_size
+ self.multiscale_dim = self.visual_dim * self.deepstack_num_level
+
+ def _get_deepstack_input_embeds(self) -> IntermediateTensors:
+        # Get deepstack_input_embeds from the buffer.
+ if self.deepstack_input_embeds is None:
+ return None
+ return IntermediateTensors({
+ f"deepstack_input_embeds_{idx}":
+ self.deepstack_input_embeds[idx]
+ for idx in range(self.deepstack_num_level)
+ })
+
+ def _set_deepstack_input_embeds(
+ self, deepstack_input_embeds: torch.Tensor) -> None:
+ if deepstack_input_embeds is None:
+ self.deepstack_input_embeds = None
+ return
+
+ self.deepstack_input_embeds = [
+ deepstack_input_embeds[idx].clone()
+ for idx in range(self.deepstack_num_level)
+ ]
+
+ def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
+ # clear deepstack_input_embeds in buffer
+ if num_tokens > 0:
+ for idx in range(self.deepstack_num_level):
+ self.deepstack_input_embeds[idx][:num_tokens].zero_()
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if (input_key in ("pixel_values", "image_embeds")
+ and "image" not in mm_input_by_modality):
+ mm_input_by_modality["image"] = (
+ self._parse_and_validate_image_input(**kwargs))
+ if (input_key in ("pixel_values_videos", "video_embeds")
+ and "video" not in mm_input_by_modality):
+ mm_input_by_modality["video"] = (
+ self._parse_and_validate_video_input(**kwargs))
+            if (input_key in ("input_audio_features", )
+ and "audio" not in mm_input_by_modality):
+ mm_input_by_modality["audio"] = (
+ self._parse_and_validate_audio_input(**kwargs))
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
+ **kwargs)
+ if not mm_input_by_modality:
+ return []
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with
+        # each tensor corresponding to a multimodal data item
+        # (image, video or audio).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ vision_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "video":
+ video_embeddings = self._process_video_input(multimodal_input)
+ multimodal_embeddings += video_embeddings
+ if modality == "audio":
+ audio_embeddings = self._process_audio_input(multimodal_input)
+ multimodal_embeddings += audio_embeddings
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
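+        # Visual embeddings arrive with the main feature and the deepstack
+        # multi-scale features concatenated along the last dim. The main
+        # part is merged at the placeholder positions as usual, while the
+        # multi-scale part is appended to inputs_embeds so the LLM can add
+        # it to the hidden states of its first few layers.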
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ deepstack_input_embeds = None
+ if multimodal_embeddings is not None and len(
+ multimodal_embeddings) != 0:
+            # TODO (ywang96): support overlapping modality embeddings so that
+            # `use_audio_in_video` will work on V1.
+ # split the feat dim to obtain multi-scale visual feature
+ if self.visual.deepstack_visual_indexes is not None:
+ multiscale_len = len(self.visual.deepstack_visual_indexes)
+ multimodal_embeddings_multiscale = []
+ for index, embeddings in enumerate(multimodal_embeddings):
+ if embeddings.shape[
+ -1] != self.config.text_config.hidden_size:
+ visual_dim = embeddings.shape[-1] // (multiscale_len +
+ 1)
+ main_dim, multi_dim = (visual_dim,
+ visual_dim * multiscale_len)
+ embeddings_main, embeddings_multiscale = torch.split(
+ embeddings, [main_dim, multi_dim], dim=-1)
+ multimodal_embeddings[index] = embeddings_main
+ multimodal_embeddings_multiscale.append(
+ embeddings_multiscale)
+ if len(multimodal_embeddings_multiscale) > 0:
+ deepstack_input_embeds = inputs_embeds.new_zeros(
+ inputs_embeds.size(0), inputs_embeds.shape[1],
+ multiscale_len * inputs_embeds.size(1))
+ deepstack_input_embeds = merge_multimodal_embeddings(
+ input_ids,
+ deepstack_input_embeds,
+ multimodal_embeddings_multiscale,
+ placeholder_token_id=[
+ self.config.image_token_id,
+ self.config.video_token_id
+ ],
+ )
+ deepstack_input_embeds = deepstack_input_embeds.view(
+ inputs_embeds.shape[0], inputs_embeds.shape[1],
+ multiscale_len * visual_dim)
+
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ multimodal_embeddings,
+ [
+ self.config.image_token_id,
+ self.config.video_token_id,
+ self.config.audio_token_id,
+ ],
+ )
+ # self._set_deepstack_input_embeds(deepstack_input_embeds)
+ if deepstack_input_embeds is None:
+ deepstack_input_embeds = torch.zeros(inputs_embeds.size(0),
+ inputs_embeds.size(1),
+ multiscale_len * visual_dim,
+ device=inputs_embeds.device,
+ dtype=inputs_embeds.dtype)
+ inputs_embeds = torch.cat([inputs_embeds, deepstack_input_embeds],
+ dim=2)
+ return inputs_embeds
+
+ def get_multimodal_embeddings_v0(
+ self, **kwargs: object) -> Optional[NestedTensors]:
+ audio_input = self._parse_and_validate_audio_input(**kwargs)
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ video_input = self._parse_and_validate_video_input(**kwargs)
+
+ if audio_input is None and image_input is None and video_input is None:
+ return None
+
+ multimodal_embeddings: list[tuple[NestedTensors, str]] = []
+
+ if audio_input is not None:
+ audio_embeds = self._process_audio_input(audio_input)
+ multimodal_embeddings.append((audio_embeds, "audio"))
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ multimodal_embeddings.append((image_embeds, "image"))
+ if video_input is not None:
+ video_embeds = self._process_video_input(video_input)
+ multimodal_embeddings.append((video_embeds, "video"))
+ return multimodal_embeddings
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[NestedTensors] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(
+ input_ids)[:, :, :self.text_dim]
+
+ if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
+ # self._set_deepstack_input_embeds(None)
+ return inputs_embeds
+
+ use_deepstack = self.visual.deepstack_visual_indexes is not None
+ if use_deepstack:
+ multiscale_len = len(self.visual.deepstack_visual_indexes)
+ visual_dim = self.text_dim
+ deepstack_input_embeds = None
+ for embeddings, modality in multimodal_embeddings:
+ if modality in ["image", "video"]:
+ deepstack_input_embeds = torch.zeros_like(
+ inputs_embeds).unsqueeze(2).repeat(
+ 1, 1, len(self.visual.deepstack_visual_indexes),
+ 1).flatten(2)
+ break
+
+ for embeddings, modality in multimodal_embeddings:
+ if modality == "audio":
+ placeholder_token_id = self.config.audio_token_id
+ if modality == "image":
+ placeholder_token_id = self.config.image_token_id
+ if modality == "video":
+ placeholder_token_id = self.config.video_token_id
+ if use_deepstack and modality in ["image", "video"]:
+ embeddings = torch.cat(embeddings)
+ embeddings, embeddings_multiscale = embeddings.split(
+ [visual_dim, visual_dim * multiscale_len], dim=-1)
+ deepstack_input_embeds = merge_multimodal_embeddings(
+ input_ids,
+ deepstack_input_embeds,
+ embeddings_multiscale,
+ placeholder_token_id=placeholder_token_id,
+ )
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, embeddings, placeholder_token_id)
+
+ if use_deepstack and deepstack_input_embeds is not None:
+ # deepstack_input_embeds = deepstack_input_embeds.view(
+ # inputs_embeds.shape[0], inputs_embeds.shape[1],
+ # multiscale_len, visual_dim).permute(2, 0, 1, 3).contiguous()
+ # self._set_deepstack_input_embeds(deepstack_input_embeds)
+ inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds),
+ dim=-1)
+
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ deepstack_input_embeds = None
+ if self.use_deepstack and inputs_embeds is not None and get_pp_group(
+ ).is_first_rank and inputs_embeds.size(2) > self.text_dim:
+ multiscale_len = len(self.visual.deepstack_visual_indexes)
+ input_embeds_multiscale = inputs_embeds[:, :, self.text_dim:]
+ inputs_embeds = inputs_embeds[:, :, :self.text_dim]
+ input_embeds_multiscale = input_embeds_multiscale.view(
+ inputs_embeds.shape[0], inputs_embeds.shape[1], multiscale_len,
+ self.text_dim).permute(2, 0, 1, 3).contiguous()
+ deepstack_input_embeds = IntermediateTensors({
+ f"deepstack_input_embeds_{idx}":
+ input_embeds_multiscale[idx]
+ for idx in range(self.deepstack_num_level)
+ })
+ # input_embeds_multiscale = self._get_deepstack_input_embeds()
+ else:
+ deepstack_input_embeds = None
+
+ hidden_states = self.language_model.model(
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ # args for deepstack
+ deepstack_input_embeds=deepstack_input_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self,
+ skip_prefixes=["talker.", "code2wav."],
+ )
+ loaded_weights = loader.load_weights(weights,
+ mapper=self.hf_to_vllm_mapper)
+
+ return loaded_weights
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
new file mode 100644
index 000000000000..417e2d847a6c
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -0,0 +1,1647 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen3_vl import (Qwen3VLProcessor,
+ Qwen3VLVideoProcessor)
+from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+ Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.video_utils import VideoMetadata
+
+from vllm.attention.layer import check_upstream_fa_availability
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+ GPTQMarlinConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargsItems, VideoItem)
+from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
+ MultiModalDataParser)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ PromptReplacement, PromptUpdate,
+ PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import _Backend, current_platform
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import uses_mrope
+from vllm.utils import is_list_of
+
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+ SupportsMultiModal, SupportsPP)
+from .qwen2_5_vl import (Qwen2_5_VisionAttention,
+ Qwen2_5_VisionRotaryEmbedding,
+ Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLImageInputs,
+ Qwen2_5_VLImagePixelInputs,
+ Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLVideoInputs,
+ Qwen2_5_VLVideoPixelInputs)
+from .qwen2_vl import Qwen2VLProcessingInfo
+from .qwen3 import Qwen3ForCausalLM, Qwen3Model
+from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
+ maybe_prefix, merge_multimodal_embeddings)
+from .vision import get_vit_attn_backend
+
+logger = init_logger(__name__)
+is_hpu = current_platform.is_hpu()
+
+if is_hpu:
+ import habana_frameworks.torch as htorch
+ import habana_frameworks.torch.core as htcore
+
+
+class Qwen3_VisionPatchEmbed(nn.Module):
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ hidden_size: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.hidden_size = hidden_size
+
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
+ self.proj = nn.Conv3d(in_channels,
+ hidden_size,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
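+        # x holds flattened patches of shape (num_patches, C * T * H * W);
+        # reshape each row back into a (C, T, H, W) volume and project it
+        # with the Conv3d patch embedding.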
+ L, C = x.shape
+ x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
+ self.patch_size)
+ x = self.proj(x).view(L, self.hidden_size)
+ return x
+
+
+class Qwen3_VisionMLP(nn.Module):
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = False,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.linear_fc1 = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.linear_fc1")
+ self.linear_fc2 = RowParallelLinear(hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ return_bias=False,
+ prefix=f"{prefix}.linear_fc2")
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
+ mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+ return mlp_output
+
+
+class Qwen3_VisionBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+ self.mlp = Qwen3_VisionMLP(dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: Optional[int] = None, # Only used for Flash Attention
+ seqlens: Optional[list[int]] = None, # Only used for xFormers
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ x = x + self.attn(self.norm1(x),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens,
+ attn_mask=attn_mask)
+
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class Qwen3_VisionPatchMerger(nn.Module):
+
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ spatial_merge_size: int = 2,
+ use_postshuffle_norm: bool = False,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+
+ self.use_postshuffle_norm = use_postshuffle_norm
+ if self.use_postshuffle_norm:
+ context_dim = self.hidden_size
+
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.use_postshuffle_norm = use_postshuffle_norm
+ self.norm = norm_layer(
+ self.hidden_size if use_postshuffle_norm else context_dim)
+ self.linear_fc1 = ColumnParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.linear_fc1")
+ self.act_fn = nn.GELU()
+ self.linear_fc2 = RowParallelLinear(self.hidden_size,
+ d_model,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.linear_fc2")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.use_postshuffle_norm:
+ x = self.norm(x.view(-1, self.hidden_size))
+ else:
+ x = self.norm(x).view(-1, self.hidden_size)
+
+ x_parallel, _ = self.linear_fc1(x)
+ x_parallel = self.act_fn(x_parallel)
+ out, _ = self.linear_fc2(x_parallel)
+ return out
+
+
+class Qwen3_VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ vision_config: Qwen3VLVisionConfig,
+ norm_eps: float = 1e-6,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_heads
+ self.num_position_embeddings = vision_config.num_position_embeddings
+ self.patch_size = vision_config.patch_size
+ self.spatial_merge_size = vision_config.spatial_merge_size
+ self.spatial_merge_unit = self.spatial_merge_size**2
+ self.temporal_patch_size = vision_config.temporal_patch_size
+ self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
+
+ self.patch_embed = Qwen3_VisionPatchEmbed(
+ patch_size=self.patch_size,
+ temporal_patch_size=self.temporal_patch_size,
+ in_channels=vision_config.in_channels,
+ hidden_size=self.hidden_size,
+ )
+
+ self.pos_embed = nn.Embedding(self.num_position_embeddings,
+ self.hidden_size)
+
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([
+ Qwen3_VisionBlock(
+ dim=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.blocks.{layer_idx}")
+ for layer_idx in range(vision_config.depth)
+ ])
+
+ self.merger = Qwen3_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ norm_layer=norm_layer,
+ spatial_merge_size=self.spatial_merge_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.merger",
+ )
+
+ self.deepstack_merger_list = nn.ModuleList([
+ Qwen3_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ spatial_merge_size=self.spatial_merge_size,
+ use_postshuffle_norm=True,
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.deepstack_merger_list.{layer_idx}")
+ for layer_idx in range(len(self.deepstack_visual_indexes))
+ ])
+
+ self.attn_backend = get_vit_attn_backend(support_fa=True)
+ if self.attn_backend != _Backend.FLASH_ATTN and \
+ check_upstream_fa_availability(
+ torch.get_default_dtype()):
+ self.attn_backend = _Backend.FLASH_ATTN
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.patch_embed.proj.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.patch_embed.proj.weight.device
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def fast_pos_embed_interpolate(self, grid_thw):
+ num_grid_per_side = int(self.num_position_embeddings**0.5)
+
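+        # Bilinearly interpolate the learned (num_grid_per_side ** 2)
+        # position-embedding table onto each (h, w) grid: idx_list collects
+        # the four surrounding table indices per target position and
+        # weight_list the matching bilinear weights.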
+ idx_list = [[] for _ in range(4)]
+ weight_list = [[] for _ in range(4)]
+
+ for t, h, w in grid_thw:
+ h_idxs = torch.linspace(0,
+ num_grid_per_side - 1,
+ h,
+ dtype=torch.float32)
+ w_idxs = torch.linspace(0,
+ num_grid_per_side - 1,
+ w,
+ dtype=torch.float32)
+
+ h_idxs_floor = h_idxs.to(torch.long)
+ w_idxs_floor = w_idxs.to(torch.long)
+ h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1,
+ max=num_grid_per_side - 1)
+ w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1,
+ max=num_grid_per_side - 1)
+
+ dh = h_idxs - h_idxs_floor
+ dw = w_idxs - w_idxs_floor
+
+ idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T +
+ w_idxs_floor[None]).flatten().tolist() * t)
+ idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T +
+ w_idxs_ceil[None]).flatten().tolist() * t)
+ idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
+ w_idxs_floor[None]).flatten().tolist() * t)
+ idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T +
+ w_idxs_ceil[None]).flatten().tolist() * t)
+
+ weight_list[0].extend(
+ ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t)
+ weight_list[1].extend(
+ ((1 - dh)[None].T * dw[None]).flatten().tolist() * t)
+ weight_list[2].extend(
+ (dh[None].T * (1 - dw)[None]).flatten().tolist() * t)
+ weight_list[3].extend(
+ (dh[None].T * dw[None]).flatten().tolist() * t)
+
+ device = self.pos_embed.weight.device
+ dtype = self.pos_embed.weight.dtype
+
+ p0 = self.pos_embed(
+ torch.tensor(
+ idx_list[0], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[0], dtype=dtype, device=device)[:, None]
+ p1 = self.pos_embed(
+ torch.tensor(
+ idx_list[1], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[1], dtype=dtype, device=device)[:, None]
+ p2 = self.pos_embed(
+ torch.tensor(
+ idx_list[2], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[2], dtype=dtype, device=device)[:, None]
+ p3 = self.pos_embed(
+ torch.tensor(
+ idx_list[3], dtype=torch.long, device=device)) * torch.tensor(
+ weight_list[3], dtype=dtype, device=device)[:, None]
+
+ patch_pos_embeds = p0 + p1 + p2 + p3
+ patch_pos_embeds = patch_pos_embeds.split(
+ [t * h * w for t, h, w in grid_thw])
+ patch_pos_embeds_permute = []
+ m_size = self.spatial_merge_size
+ for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw):
+ pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size,
+ m_size, -1).permute(0, 1, 3, 2, 4,
+ 5).flatten(0, 4)
+ patch_pos_embeds_permute.append(pos_embed)
+ patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
+ return patch_pos_embeds
+
+ def compute_attn_mask_seqlen(
+ self,
+ cu_seqlens: torch.Tensor,
+ ) -> tuple[Optional[int], Optional[list[int]]]:
+ max_seqlen, seqlens = None, None
+ if self.attn_backend == _Backend.FLASH_ATTN:
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ elif self.attn_backend == _Backend.XFORMERS:
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ return max_seqlen, seqlens
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: list[list[int]],
+ ) -> torch.Tensor:
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ cu_seqlens = torch.repeat_interleave(
+ grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ dtype=grid_thw.dtype
+ if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ hidden_states = hidden_states.unsqueeze(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+ max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens)
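+            # Capture deepstack features from the selected intermediate
+            # layers; they are concatenated to the final embeddings below
+            # and later injected into the early LLM decoder layers.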
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_merger_idx = self.deepstack_visual_indexes.index(
+ layer_num)
+ deepstack_feature = self.deepstack_merger_list[
+ deepstack_merger_idx](hidden_states)
+ deepstack_feature_lists.append(deepstack_feature)
+ hidden_states = self.merger(hidden_states)
+ hidden_states = torch.cat(
+ [hidden_states] + deepstack_feature_lists,
+ dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("attn.qkv.", "attn.q.", "q"),
+ ("attn.qkv.", "attn.k.", "k"),
+ ("attn.qkv.", "attn.v.", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Qwen3_VisionTransformerStaticShape(Qwen3_VisionTransformer):
+ """
+ Here we overwrite some of the methods of Qwen3_VisionTransformer
+ to make the model more friendly to static shapes. Specifically,
+ we split the forward method into:
+ - pre_attn (dynamic)
+ - forward (static shape)
+ and we should call get_image_embeds instead of forward, allowing
+ the forward method ro run with HPU_Graphs, whereas the
+ pre_attn and post_attn methods are allow to be dynamic.
+ """
+
+ def pad_multimodal_data(self,
+ pixel_values,
+ vision_buckets,
+ constant_value=0):
+
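+        # Pad the patch dimension up to the nearest bucket size so that the
+        # vision blocks run with a small set of static shapes on HPU.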
+ desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
+ pixel_values.shape[0])
+ padding_len = desired_number_of_pixels - pixel_values.shape[0]
+ if padding_len <= 0:
+ return pixel_values
+
+        logger.debug("Padding pixel values from %d to %d",
+                     pixel_values.shape[0], desired_number_of_pixels)
+
+ pixel_values = F.pad(pixel_values, (0, 0, 0, padding_len), "constant",
+ constant_value)
+
+ return pixel_values
+
+ def pre_attn(self, x: torch.Tensor, grid_thw: torch.Tensor,
+ vision_buckets):
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+ hidden_states = hidden_states + pos_embeds
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ attention_mask = torch.ones(hidden_states.size(0),
+ 1).to(device=self.device)
+
+ hidden_states = self.pad_multimodal_data(hidden_states, vision_buckets,
+ 0)
+ rotary_pos_emb = self.pad_multimodal_data(rotary_pos_emb,
+ vision_buckets, -100)
+ attention_mask = self.pad_multimodal_data(attention_mask,
+ vision_buckets, 0)
+
+ return hidden_states, rotary_pos_emb, attention_mask
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ rotary_pos_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ hidden_states = x.unsqueeze(1)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+ deepstack_feature_lists = []
+ for layer_num, blk in enumerate(self.blocks):
+ hidden_states = blk(hidden_states,
+ rotary_pos_emb=rotary_pos_emb,
+ attn_mask=attn_mask,
+ cu_seqlens=None)
+ if layer_num in self.deepstack_visual_indexes:
+ deepstack_merger_idx = self.deepstack_visual_indexes.index(
+ layer_num)
+ deepstack_feature = self.deepstack_merger_list[
+ deepstack_merger_idx](hidden_states)
+ deepstack_feature_lists.append(deepstack_feature)
+ hidden_states = self.merger(hidden_states)
+ hidden_states = torch.cat(
+ [hidden_states] + deepstack_feature_lists,
+ dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
+ return hidden_states
+
+ def get_image_embeds(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ vision_buckets,
+ ) -> torch.Tensor:
+
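+        # Run the vision tower image-by-image (and frame-by-frame for
+        # videos), padding each slice to a bucketed length so attention sees
+        # static shapes; the padded tail is dropped again after the merger.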
+ offset = 0
+ results = []
+ # process each image one by one
+ for img_idx in range(grid_thw.shape[0]):
+ img_shape = grid_thw[img_idx, :].unsqueeze(0).clone()
+ # For video, we process frames separately
+ grid_t = grid_thw[img_idx, 0]
+ img_shape[0, 0] = 1
+ curr_img_size = img_shape.prod()
+ for _ in torch.arange(0, grid_t):
+ pixel_values_curr_img = pixel_values[offset:offset +
+ curr_img_size, :]
+
+ offset += curr_img_size
+
+ (pixel_values_curr_img_padded, rot_pos_emb,
+ attention_mask) = self.pre_attn(pixel_values_curr_img,
+ img_shape, vision_buckets)
+
+ fullatt_block_attn_mask = \
+ attention_mask.squeeze(1).unsqueeze(0) * attention_mask
+
+ extra_forward_kwargs = {}
+ if htorch.utils.internal.is_lazy():
+ padded_len = pixel_values_curr_img_padded.shape[0]
+ use_graph = vision_buckets.use_graph(padded_len)
+ extra_forward_kwargs.update(
+ {"bypass_hpu_graphs": not use_graph})
+
+ htcore.mark_step()
+ hidden_states = self.forward(pixel_values_curr_img_padded,
+ rotary_pos_emb=rot_pos_emb,
+ attn_mask=fullatt_block_attn_mask,
+ **extra_forward_kwargs)
+ htcore.mark_step()
+
+ post_embed_size = curr_img_size // self.spatial_merge_unit
+ results += [hidden_states[:post_embed_size, :]]
+
+ results_cat = torch.concat(results)
+ image_embeds = results_cat
+ return image_embeds
+
+
+class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Qwen3VLConfig)
+
+ def get_hf_processor(self, **kwargs: object) -> Qwen3VLProcessor:
+ return self.ctx.get_hf_processor(
+ Qwen3VLProcessor,
+ use_fast=kwargs.pop("use_fast", True),
+ **kwargs,
+ )
+
+ def get_tokenizer(self):
+ return self.ctx.tokenizer
+
+ def get_image_processor(self,
+ **kwargs: object) -> Qwen2VLImageProcessorFast:
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
+ return self.get_hf_processor(**kwargs).video_processor
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 2,
+ do_resize: bool = True,
+ image_processor: Optional[Qwen2VLImageProcessorFast],
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+ patch_size = vision_config.patch_size
+ merge_size = vision_config.spatial_merge_size
+ temporal_patch_size = vision_config.temporal_patch_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * merge_size,
+ min_pixels=image_processor.size["shortest_edge"],
+ max_pixels=image_processor.size["longest_edge"],
+ )
+ preprocessed_size = ImageSize(width=resized_width,
+ height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width,
+ height=image_height)
+
+ padded_num_frames = num_frames + num_frames % temporal_patch_size
+
+ grid_t = max(padded_num_frames // temporal_patch_size, 1)
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
+ num_patches = grid_t * grid_h * grid_w
+ num_vision_tokens = num_patches // (merge_size**2)
+
+ return preprocessed_size, num_vision_tokens
+
+ def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
+ video_fps: float, merge_size: int):
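+        # Convert sampled frame indices to seconds, then take the midpoint
+        # of each group of merge_size consecutive timestamps so there is one
+        # timestamp per temporal patch (the last group is padded with the
+        # final frame index if needed).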
+ if not isinstance(indices, list):
+ indices = indices.tolist()
+ if len(indices) % merge_size != 0:
+ # don't update metadata's frames_indices directly
+ indices = indices + [indices[-1]
+ ] * (merge_size - len(indices) % merge_size)
+ timestamps = [idx / video_fps for idx in indices]
+ timestamps = [(timestamps[i] + timestamps[i + merge_size - 1]) / 2
+ for i in range(0, len(timestamps), merge_size)]
+ return timestamps
+
+ def _get_video_second_idx(
+ self,
+ metadata: dict[str, Any],
+ do_sample_frames: Optional[bool] = None,
+ sampled_fps: Optional[float] = None) -> list[int]:
+ video_processor = self.get_video_processor()
+ merge_size = video_processor.merge_size
+ indices = metadata["frames_indices"]
+
+ # metadata["fps"] refers to the true fps of the input video.
+ video_fps = metadata["fps"]
+ if do_sample_frames is None:
+ do_sample_frames = metadata.get("do_sample_frames", False)
+
+ # If video frames are sampled in HF processor (instead of vLLM
+ # video loader), we need to re-calculate the indices from original
+ # metadata.
+ if do_sample_frames:
+ # here video_fps is the fps of the sampled video, and
+ # metadata["fps"] refers to the fps of the original video.
+ video_fps = sampled_fps if sampled_fps else video_processor.fps
+ total_num_frames = metadata["total_num_frames"]
+ num_frames = int(total_num_frames / metadata["fps"] * video_fps)
+ num_frames = min(
+ min(max(num_frames, video_processor.min_frames),
+ video_processor.max_frames), total_num_frames)
+ indices = np.linspace(0, total_num_frames - 1,
+ num_frames).round().astype(int).tolist()
+ timestamps = self._calculate_timestamps(indices, video_fps, merge_size)
+ return timestamps
+
+
+class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+
+ image_token = "<|vision_start|><|image_pad|><|vision_end|>"
+ video_token = "<|vision_start|><|video_pad|><|vision_end|>"
+
+ return image_token * num_images + video_token * num_videos
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+
+ target_width, target_height = (
+ self.info.get_image_size_with_most_features())
+ target_num_frames = self.info.get_num_frames_with_most_features(
+ seq_len, mm_counts)
+ return {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ "video":
+ self._get_dummy_videos(
+ width=target_width,
+ height=target_height,
+ num_frames=target_num_frames,
+ num_videos=num_videos,
+ ),
+ }
+
+ def _get_dummy_videos(
+ self,
+ *,
+ width: int,
+ height: int,
+ num_frames: int,
+ num_videos: int,
+ ) -> list[VideoItem]:
+ num_frames = max(num_frames, 2)
+ video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
+ video_items = []
+ for i in range(num_videos):
+ video_metadata = {
+ "fps": 2.0,
+ "duration": num_frames / 2.0,
+ "total_num_frames": num_frames,
+ "frames_indices": [i for i in range(num_frames)],
+ "video_backend": "opencv",
+ "do_sample_frames": False,
+ }
+ video_item = (video.copy(), video_metadata)
+ video_items.append(video_item)
+ return video_items
+
+
+class Qwen3VLMultiModalProcessor(
+        BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
+
+ def _get_data_parser(self) -> MultiModalDataParser:
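+        # Videos must carry metadata (fps, frames_indices, ...) so per-frame
+        # timestamps can be computed for the prompt placeholders.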
+ return MultiModalDataParser(video_needs_metadata=True)
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ mm_data = dict(mm_data)
+ processor = self.info.get_hf_processor(**mm_kwargs)
+
+        # Separate video processing from image processing, because videos
+        # are processed into several image patches.
+ if ("videos" in mm_data and isinstance(mm_data["videos"], list)
+ and len(mm_data["videos"]) > 0):
+ video_grid_thw_lst = []
+ pixel_values_videos_lst = []
+
+ for item_idx, item in enumerate(mm_data.pop("videos", [])):
+ video_array, metadata = item
+
+                # NOTE: @JJJYmmm the new attr metadata.frames_indices holds
+                # the sampled frame indices of pre-sampled videos, which are
+                # used to calculate the timestamps. Make sure that
+                # do_sample_frames in mm_kwargs is false for pre-sampled
+                # videos.
+
+                # NOTE: a copy of mm_kwargs is created to update
+                # do_sample_frames, otherwise mm_hash for the object will be
+                # incorrect.
+ video_mm_kwargs = dict(**mm_kwargs)
+ if "do_sample_frames" not in video_mm_kwargs:
+ # qwen_vl_utils already has "do_sample_frames" in
+ # mm_kwargs, don't overwrite it.
+ video_mm_kwargs["do_sample_frames"] = metadata.get(
+ "do_sample_frames", False)
+
+ metadata = VideoMetadata(**{
+ k: metadata[k]
+ for k in metadata if k != "do_sample_frames"
+ })
+
+ video_mm_data = dict()
+ video_mm_data["videos"] = [[video_array]]
+ video_mm_data["video_metadata"] = [[metadata]]
+
+ video_outputs = super()._call_hf_processor(
+ prompt="<|vision_start|><|video_pad|><|vision_end|>",
+ mm_data=video_mm_data,
+ mm_kwargs=video_mm_kwargs,
+ )
+ input_ids = video_outputs.pop("input_ids")
+ video_placeholder = processor.tokenizer.batch_decode(
+ input_ids)[0]
+ prompt = prompt.replace(
+ "<|vision_start|><|video_pad|><|vision_end|>",
+ video_placeholder,
+ 1,
+ )
+
+ video_grid_thw_lst.append(video_outputs["video_grid_thw"])
+ pixel_values_videos_lst.append(
+ video_outputs["pixel_values_videos"])
+ video_outputs = dict(
+ pixel_values_videos=torch.cat(pixel_values_videos_lst),
+ video_grid_thw=torch.cat(video_grid_thw_lst),
+ )
+ else:
+ video_outputs = dict()
+
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ )
+ combined_outputs = dict(
+ processed_outputs,
+ **video_outputs,
+ )
+ return BatchFeature(combined_outputs)
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+
+ video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+ video_grid_sizes = video_grid_thw.prod(-1)
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
+ image_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_grid_sizes),
+ video_embeds=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_grid_sizes),
+ video_grid_thw=MultiModalFieldConfig.batched("video"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(
+ **hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ hf_config = self.info.get_hf_config()
+
+ video_token_id = hf_config.video_token_id
+ vision_start_token_id = hf_config.vision_start_token_id
+ vision_end_token_id = hf_config.vision_end_token_id
+
+ merge_length = image_processor.merge_size**2
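+        # Each merge_size x merge_size window of vision patches is merged
+        # into a single LLM token.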
+
+ def get_image_replacement_qwen3vl(item_idx: int):
+ grid_thw = out_mm_kwargs["image_grid_thw"][item_idx]
+ assert isinstance(grid_thw, torch.Tensor)
+
+ num_tokens = int(grid_thw.prod()) // merge_length
+ return [hf_processor.image_token_id] * num_tokens
+
+ def get_video_replacement_qwen3vl(item_idx: int):
+ grid_thw = out_mm_kwargs["video_grid_thw"][item_idx]
+ assert isinstance(grid_thw, torch.Tensor)
+
+ video, metadata = mm_items["video"][item_idx]
+ do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
+ sampled_fps = hf_processor_mm_kwargs.get("fps")
+ if is_list_of(sampled_fps, float):
+ sampled_fps = sampled_fps[item_idx]
+ timestamps = self.info._get_video_second_idx(
+ metadata, do_sample_frames, sampled_fps)
+
+            assert len(timestamps) == grid_thw[0], (
+                f"The number of timestamps ({len(timestamps)}) should equal "
+                f"the temporal grid size ({grid_thw[0]}).")
+
+ frames_idx_token = [
+ tokenizer.encode(f"<{curr_time:.1f} seconds>",
+ add_special_tokens=False)
+ for curr_time in timestamps
+ ]
+ num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
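+            # Each frame is rendered as its timestamp text (e.g.
+            # "<1.0 seconds>") followed by <|vision_start|>, the per-frame
+            # video tokens, and <|vision_end|>.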
+ placeholder = []
+ for frame_idx in frames_idx_token:
+ placeholder.extend(frame_idx)
+ placeholder.extend([vision_start_token_id] +
+ [video_token_id] * num_tokens_per_frame +
+ [vision_end_token_id])
+ return PromptUpdateDetails.select_token_id(placeholder,
+ video_token_id)
+
+ return [
+ PromptReplacement(
+ modality="image",
+ target=hf_processor.image_token,
+ replacement=get_image_replacement_qwen3vl,
+ ),
+
+            # NOTE: We match on the string on purpose, since searching for a
+            # sequence of token ids is slower.
+ PromptReplacement(
+ modality="video",
+ target="<|vision_start|><|video_pad|><|vision_end|>",
+ replacement=get_video_replacement_qwen3vl,
+ ),
+ ]
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for
+        # qwen3-vl, otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+        # the same shape as inputs_embeds
+ "deepstack_input_embeds": 0
+ })
+class Qwen3LLMModel(Qwen3Model):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ if not get_pp_group().is_first_rank:
+ assert self.start_layer >= len(
+ vllm_config.model_config.hf_config.vision_config.
+ deepstack_visual_indexes), (
+ "start_layer should be greater than or equal to "
+ "len(deepstack_visual_indexes)")
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ # args for deepstack
+ deepstack_input_embeds: Optional[IntermediateTensors] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ if is_hpu:
+ htorch.core.mark_step()
+ for layer_idx, layer in enumerate(
+ self.layers[self.start_layer:self.end_layer]):
+ layer_idx = layer_idx + self.start_layer
+
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ )
+
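+            # Deepstack: add the multiscale visual features to the hidden
+            # states of the first len(deepstack_input_embeds) decoder layers.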
+ if deepstack_input_embeds is not None and \
+ layer_idx in range(0, len(deepstack_input_embeds)):
+ hidden_states = hidden_states + deepstack_input_embeds[
+ f"deepstack_input_embeds_{layer_idx}"]
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super(Qwen3ForCausalLM, self).__init__()
+ config = vllm_config.model_config.hf_config.text_config
+ quant_config = vllm_config.quant_config
+ lora_config = vllm_config.lora_config
+
+ self.config = config
+ self.lora_config = lora_config
+
+ self.quant_config = quant_config
+ self.model = Qwen3LLMModel(vllm_config=vllm_config, prefix=prefix)
+
+ if get_pp_group().is_last_rank:
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix="lm_head")
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+ info=Qwen3VLProcessingInfo,
+ dummy_inputs=Qwen3VLDummyInputsBuilder)
+class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsLoRA, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "model.visual.": "visual.",
+ "lm_head.": "language_model.lm_head.",
+ "model.language_model.": "language_model.model.",
+ })
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<|vision_start|><|image_pad|><|vision_end|>"
+ if modality.startswith("video"):
+ return "<|vision_start|><|video_pad|><|vision_end|>"
+
+ raise ValueError("Only image or video modality is supported")
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
+ super().__init__()
+ config: Qwen3VLConfig = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self.text_dim = config.text_config.hidden_size
+
+        if is_hpu:
+            vision_transformer_cls = Qwen3_VisionTransformerStaticShape
+        else:
+            vision_transformer_cls = Qwen3_VisionTransformer
+
+        self.visual = vision_transformer_cls(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+
+ self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
+ prefix=maybe_prefix(
+ prefix,
+ "language_model"))
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ self.use_deepstack = hasattr(config.vision_config,
+ 'deepstack_visual_indexes')
+ self.deepstack_num_level = len(
+ config.vision_config.deepstack_visual_indexes
+ ) if self.use_deepstack else 0
+ # register buffer for deepstack
+ self.deepstack_input_embeds = [
+ torch.zeros(1, vllm_config.scheduler_config.max_num_batched_tokens,
+ config.text_config.hidden_size)
+ for _ in range(self.deepstack_num_level)
+ ] if self.use_deepstack else None
+
+ def _get_deepstack_input_embeds(self) -> IntermediateTensors:
+        # get deepstack_input_embeds from the buffer
+ return IntermediateTensors({
+ f"deepstack_input_embeds_{idx}":
+ self.deepstack_input_embeds[idx]
+ for idx in range(self.deepstack_num_level)
+ })
+
+ def _set_deepstack_input_embeds(
+ self, deepstack_input_embeds: torch.Tensor) -> None:
+ self.deepstack_input_embeds = [
+ deepstack_input_embeds[idx].clone()
+ for idx in range(self.deepstack_num_level)
+ ]
+
+ def _clear_deepstack_input_embeds(self, batch_size: int,
+ num_tokens: int) -> None:
+ # clear deepstack_input_embeds in buffer
+ if num_tokens > 0:
+ for idx in range(self.deepstack_num_level):
+ self.deepstack_input_embeds[
+ idx][:batch_size, :num_tokens].zero_()
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+ # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+ # seems to avoid vision encoder sections for some models.
+ if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+ return None
+ return quant_config
+
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
+ name: str) -> torch.Tensor:
+ if not isinstance(mm_input, (torch.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if isinstance(mm_input, torch.Tensor):
+ if mm_input.ndim == 2:
+ return mm_input
+ if mm_input.ndim != 3:
+ raise ValueError(f"{name} should be 2D or batched 3D tensor. "
+ f"Got ndim: {mm_input.ndim} "
+ f"(shape={mm_input.shape})")
+ return torch.concat(list(mm_input))
+ else:
+ return torch.concat(mm_input)
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None:
+ pixel_values = self._validate_and_reshape_mm_tensor(
+ pixel_values, "image pixel values")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ return Qwen2_5_VLImagePixelInputs(type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw)
+
+ if image_embeds is not None:
+ image_embeds = self._validate_and_reshape_mm_tensor(
+ image_embeds, "image embeds")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(image_embeds, torch.Tensor):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+ return Qwen2_5_VLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw)
+
+ def _parse_and_validate_video_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]:
+ pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+ video_embeds = kwargs.pop("video_embeds", None)
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
+ second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
+
+ if pixel_values_videos is None and video_embeds is None:
+ return None
+
+ if pixel_values_videos is not None:
+ pixel_values_videos = self._validate_and_reshape_mm_tensor(
+ pixel_values_videos, "video pixel values")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ return Qwen2_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ )
+
+ if video_embeds is not None:
+ video_embeds = self._validate_and_reshape_mm_tensor(
+ video_embeds, "video embeds")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ if not isinstance(video_embeds, torch.Tensor):
+ raise ValueError("Incorrect type of video embeddings. "
+ f"Got type: {type(video_embeds)}")
+ return Qwen2_5_VLVideoEmbeddingInputs(
+ type="video_embeds",
+ video_embeds=video_embeds,
+ video_grid_thw=video_grid_thw)
+
+ def _process_image_input(
+ self,
+ image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+ if is_hpu:
+ assert isinstance(self.visual,
+ Qwen3_VisionTransformerStaticShape)
+ image_embeds = self.visual.get_image_embeds(
+ pixel_values,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each image item.
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
+ merge_size = self.visual.spatial_merge_size
+ sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
+ (merge_size * merge_size)).tolist()
+ return image_embeds.split(sizes)
+
+ def _process_video_input(
+ self,
+ video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
+ grid_thw_list = grid_thw.tolist()
+
+ if video_input["type"] == "video_embeds":
+ video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ if is_hpu:
+ assert isinstance(self.visual,
+ Qwen3_VisionTransformerStaticShape)
+ video_embeds = self.visual.get_image_embeds(
+ pixel_values_videos,
+ grid_thw=grid_thw,
+ vision_buckets=self.vision_buckets,
+ )
+ else:
+ video_embeds = self.visual(pixel_values_videos,
+ grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each video item.
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
+ merge_size = self.visual.spatial_merge_size
+ sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
+ (merge_size * merge_size)).tolist()
+ return video_embeds.split(sizes)
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ mm_input_by_modality = {}
+ for input_key in kwargs:
+ if input_key in ("pixel_values", "image_embeds"
+ ) and "image" not in mm_input_by_modality:
+ mm_input_by_modality[
+ "image"] = self._parse_and_validate_image_input(**kwargs)
+ if input_key in ("pixel_values_videos", "video_embeds"
+ ) and "video" not in mm_input_by_modality:
+ mm_input_by_modality[
+ "video"] = self._parse_and_validate_video_input(**kwargs)
+ return mm_input_by_modality
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+
+ mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
+ **kwargs)
+ if not mm_input_by_modality:
+ return None
+
+        # The result multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in mm_input_by_modality:
+ multimodal_input = mm_input_by_modality[modality]
+ if modality == "image":
+ vision_embeddings = self._process_image_input(multimodal_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "video":
+ video_embeddings = self._process_video_input(multimodal_input)
+ multimodal_embeddings += video_embeddings
+ return multimodal_embeddings
+
+ def _compute_deepstack_embeds(
+ self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings) -> torch.Tensor:
+ visual_lens = [
+ x.shape[0] if isinstance(x, torch.Tensor) else len(x)
+ for x in multimodal_embeddings
+ ]
+ multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)
+
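+        # Each vision embedding packs the main hidden states plus one slice
+        # per deepstack level along the last dim, hence the division by
+        # (deepstack_num_level + 1).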
+ visual_dim = multimodal_embeddings_cat.shape[-1] // (
+ self.deepstack_num_level + 1)
+
+ main_dim, multi_dim = visual_dim, visual_dim * self.deepstack_num_level
+ multimodal_embeddings_main, multimodal_embeddings_multiscale = torch.split( # noqa:E501
+ multimodal_embeddings_cat, [main_dim, multi_dim],
+ dim=-1)
+
+ multimodal_embeddings = torch.split(multimodal_embeddings_main,
+ visual_lens,
+ dim=0)
+ multimodal_embeddings_multiscale = torch.split(
+ multimodal_embeddings_multiscale, visual_lens, dim=0)
+
+        deepstack_input_embeds = inputs_embeds.new_zeros(
+            inputs_embeds.size(0), inputs_embeds.size(1),
+            self.deepstack_num_level * visual_dim)
+
+ deepstack_input_embeds = merge_multimodal_embeddings(
+ input_ids,
+ deepstack_input_embeds,
+ multimodal_embeddings_multiscale,
+ placeholder_token_id=[
+ self.config.image_token_id, self.config.video_token_id
+ ],
+ )
+        # Keep the multiscale embeddings flattened as
+        # (batch, seq, deepstack_num_level * visual_dim); forward() splits
+        # them back into per-level tensors before the decoder layers.
+        return deepstack_input_embeds, multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ deepstack_input_embeds = None
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            if self.use_deepstack:
+                deepstack_input_embeds, multimodal_embeddings = self._compute_deepstack_embeds(  # noqa:E501
+                    input_ids, inputs_embeds, multimodal_embeddings)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                [self.config.image_token_id, self.config.video_token_id])
+
+ # if self.use_deepstack:
+ # if deepstack_input_embeds is None:
+ # deepstack_input_embeds = torch.zeros_like(
+ # inputs_embeds).unsqueeze(0).repeat(
+ # self.deepstack_num_level, 1, 1, 1).contiguous()
+ # self._set_deepstack_input_embeds(deepstack_input_embeds)
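+        # Instead of stashing the multiscale embeddings in a side buffer, we
+        # concatenate them to inputs_embeds along the hidden dim so that they
+        # travel with the text embeddings; forward() slices them off again.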
+ if deepstack_input_embeds is None:
+ deepstack_input_embeds = torch.zeros(inputs_embeds.size(0),
+ inputs_embeds.size(1),
+ self.deepstack_num_level *
+ self.text_dim,
+ device=inputs_embeds.device,
+ dtype=inputs_embeds.dtype)
+ inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds),
+ dim=-1)
+ return inputs_embeds
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: torch.Tensor,
+ image_input: Optional[Qwen2_5_VLImageInputs] = None,
+ video_input: Optional[Qwen2_5_VLVideoInputs] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.get_input_embeddings(
+ input_ids)[:, :, :self.text_dim]
+
+ if self.use_deepstack:
+ visual_dim = self.text_dim
+ deepstack_input_embeds = None
+ if image_input is not None or video_input is not None:
+ deepstack_input_embeds = torch.zeros_like(
+ inputs_embeds).unsqueeze(2).repeat(
+ 1, 1, self.deepstack_num_level, 1).flatten(2)
+
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ if self.use_deepstack:
+ image_embeds = torch.cat(image_embeds)
+
+ image_embeds, image_embeds_multiscale = image_embeds.split(
+ [visual_dim, visual_dim * self.deepstack_num_level],
+ dim=-1)
+
+ deepstack_input_embeds = merge_multimodal_embeddings(
+ input_ids,
+ deepstack_input_embeds,
+ image_embeds_multiscale,
+ placeholder_token_id=self.config.image_token_id,
+ )
+
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=self.config.image_token_id,
+ )
+
+ if video_input is not None:
+ video_embeds = self._process_video_input(video_input)
+ if self.use_deepstack:
+ video_embeds = torch.cat(video_embeds)
+
+ video_embeds, video_embeds_multiscale = video_embeds.split(
+ [visual_dim, visual_dim * self.deepstack_num_level],
+ dim=-1)
+
+ deepstack_input_embeds = merge_multimodal_embeddings(
+ input_ids,
+ deepstack_input_embeds,
+ video_embeds_multiscale,
+ placeholder_token_id=self.config.video_token_id,
+ )
+
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ video_embeds,
+ placeholder_token_id=self.config.video_token_id,
+ )
+
+ if self.use_deepstack and deepstack_input_embeds is not None:
+ # deepstack_input_embeds = deepstack_input_embeds.view(
+ # inputs_embeds.shape[0], inputs_embeds.shape[1],
+ # self.deepstack_num_level, visual_dim).permute(2, 0, 1,
+ # 3).contiguous()
+ # self._set_deepstack_input_embeds(deepstack_input_embeds)
+ inputs_embeds = torch.cat((inputs_embeds, deepstack_input_embeds),
+ dim=-1)
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ """Run forward pass for Qwen3VL.
+
+ Args:
+ input_ids: Flattened (concatenated) input_ids corresponding to a
+ batch.
+ positions: Flattened (concatenated) position ids corresponding to a
+ batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen3VL
+                open-source models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+ pixel_values: Pixel values to be fed to a model.
+ `None` if no images are passed.
+ image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
+ `None` if no images are passed.
+ pixel_values_videos: Pixel values of videos to be fed to a model.
+ `None` if no videos are passed.
+ video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
+ `None` if no videos are passed.
+ """
+
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated by the model runner
+        # via `get_multimodal_embeddings` and `get_input_embeddings`; this
+        # condition is only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ video_input = self._parse_and_validate_video_input(**kwargs)
+
+ if image_input is None and video_input is None:
+ inputs_embeds = None
+ else:
+ if uses_mrope(self.config):
+ assert positions.ndim == 2 and positions.size(0) == 3, (
+ "multimodal section rotary embedding requires "
+ f"(3, seq_len) positions, but got {positions.size()}")
+ inputs_embeds = self.get_input_embeddings_v0(
+ input_ids,
+ image_input=image_input,
+ video_input=video_input)
+ input_ids = None
+
+ deepstack_input_embeds = None
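+        # If the input embeddings carry the flattened deepstack slices
+        # (hidden dim > text_dim), split them into per-level tensors that are
+        # added to the hidden states of the first decoder layers.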
+ if self.use_deepstack and inputs_embeds is not None and get_pp_group(
+ ).is_first_rank and inputs_embeds.size(2) > self.text_dim:
+ multiscale_len = len(self.visual.deepstack_visual_indexes)
+ input_embeds_multiscale = inputs_embeds[:, :, self.text_dim:]
+ inputs_embeds = inputs_embeds[:, :, :self.text_dim]
+ input_embeds_multiscale = input_embeds_multiscale.view(
+ inputs_embeds.shape[0], inputs_embeds.shape[1], multiscale_len,
+ self.text_dim).permute(2, 0, 1, 3).contiguous()
+ deepstack_input_embeds = IntermediateTensors({
+ f"deepstack_input_embeds_{idx}":
+ input_embeds_multiscale[idx]
+ for idx in range(self.deepstack_num_level)
+ })
+ # deepstack_input_embeds = self._get_deepstack_input_embeds()
+ else:
+ deepstack_input_embeds = None
+
+ hidden_states = self.language_model.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ # args for deepstack
+ deepstack_input_embeds=deepstack_input_embeds,
+ )
+
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="model.visual.merger",
+ tower_model="model.visual.",
+ )
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
new file mode 100644
index 000000000000..04f87a078cf5
--- /dev/null
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -0,0 +1,356 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen3-VL-MoE model compatible with HuggingFace weights."""
+import typing
+from collections.abc import Iterable
+from typing import Callable, Optional, Union
+
+import torch
+from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import (
+ Qwen3VLMoeConfig)
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
+
+from .qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from .qwen3_vl import (Qwen3_VisionTransformer,
+ Qwen3_VisionTransformerStaticShape,
+ Qwen3VLDummyInputsBuilder,
+ Qwen3VLForConditionalGeneration,
+ Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
+from .utils import is_pp_missing_parameter, maybe_prefix
+
+logger = init_logger(__name__)
+is_hpu = current_platform.is_hpu()
+
+if is_hpu:
+ import habana_frameworks.torch as htorch
+
+
+class Qwen3VLMoeProcessingInfo(Qwen3VLProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Qwen3VLMoeConfig)
+
+
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+        # positions is of shape (3, seq_len) if mrope is enabled for
+        # qwen3-vl, otherwise (seq_len, ).
+        "positions": -1,
+        "intermediate_tensors": 0,
+        "inputs_embeds": 0,
+        # the same shape as inputs_embeds
+ "deepstack_input_embeds": 0
+ })
+class Qwen3MoeLLMModel(Qwen3MoeModel):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ if not get_pp_group().is_first_rank:
+ assert self.start_layer >= len(
+ vllm_config.model_config.hf_config.vision_config.
+ deepstack_visual_indexes), (
+ "start_layer should be greater than or equal to "
+ "len(deepstack_visual_indexes)")
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ deepstack_input_embeds: Optional[IntermediateTensors] = None,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+ if is_hpu:
+ htorch.core.mark_step()
+ for layer_idx, layer in enumerate(
+ self.layers[self.start_layer:self.end_layer]):
+ layer_idx = layer_idx + self.start_layer
+
+ hidden_states, residual = layer(
+ positions,
+ hidden_states,
+ residual,
+ )
+
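+            # Deepstack: add the multiscale visual features to the hidden
+            # states of the first len(deepstack_input_embeds) decoder layers.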
+ if deepstack_input_embeds is not None and \
+ layer_idx in range(0, len(deepstack_input_embeds)):
+ hidden_states = hidden_states + deepstack_input_embeds[
+ f"deepstack_input_embeds_{layer_idx}"]
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+ def load_fused_expert_weights(self, name: str, params_dict: dict,
+ loaded_weight: torch.Tensor, shard_id: str,
+ num_experts: int):
+ param = params_dict[name]
+ weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+ for expert_id in range(num_experts):
+ curr_expert_weight = loaded_weight[expert_id]
+ weight_loader(param, curr_expert_weight, name, shard_id, expert_id)
+ return True
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale",
+ ".v_scale", "_v_scale", ".weight_scale",
+ "_weight_scale", ".input_scale", "_input_scale")
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts)
+ is_fused_expert = False
+ fused_expert_params_mapping = [
+ ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+ ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+ ]
+ num_experts = self.config.num_experts
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if ("experts.gate_up_proj" in name
+ or "experts.down_proj" in name):
+ is_fused_expert = True
+ expert_params_mapping = fused_expert_params_mapping
+
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if "mlp.experts" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(ignore_suffixes) and name not in params_dict:
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ if name.endswith("scale"):
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ if weight_loader == default_weight_loader:
+ weight_loader(param, loaded_weight)
+ else:
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+                    # This is an expert weight; it should not be loaded
+                    # as a regular weight later.
+ is_expert_weight = True
+ name_mapped = name.replace(weight_name, param_name)
+ if is_fused_expert:
+ loaded_weight = loaded_weight.transpose(-1,
+ -2) # no bias
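+                        # The fused gate_up_proj checkpoint packs the gate
+                        # and up projections together; split it and load the
+                        # halves as the w1 (gate) and w3 (up) shards.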
+ if "experts.gate_up_proj" in name:
+ loaded_weight = loaded_weight.chunk(2, dim=-2)
+ self.load_fused_expert_weights(
+ name_mapped, params_dict, loaded_weight[0],
+ "w1", num_experts)
+ self.load_fused_expert_weights(
+ name_mapped, params_dict, loaded_weight[1],
+ "w3", num_experts)
+ else:
+ # down_proj
+ self.load_fused_expert_weights(
+ name_mapped, params_dict, loaded_weight,
+ shard_id, num_experts)
+ else:
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+ # Skip loading extra parameters for GPTQ/modelopt models
+ if name_mapped.endswith(
+ ignore_suffixes
+ ) and name_mapped not in params_dict:
+ continue
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(Callable[..., bool],
+ param.weight_loader)
+ weight_loader(param,
+ loaded_weight,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ name = name_mapped
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+ # Skip loading extra parameters for GPTQ/modelopt models.
+ if name.endswith(
+ ignore_suffixes) and name not in params_dict:
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ # Remapping the name of FP8 kv-scale.
+ if name.endswith("kv_scale"):
+ remapped_kv_scale_name = name.replace(
+ ".kv_scale", ".attn.kv_scale")
+ if remapped_kv_scale_name not in params_dict:
+ logger.warning_once(
+ "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
+ name,
+ remapped_kv_scale_name,
+ )
+ continue
+ else:
+ name = remapped_kv_scale_name
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super(Qwen3MoeForCausalLM, self).__init__()
+ self.config = vllm_config.model_config.hf_config.text_config
+ self.quant_config = vllm_config.quant_config
+ self.model = Qwen3MoeLLMModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+ self.lm_head = ParallelLMHead(self.config.vocab_size,
+ self.config.hidden_size,
+ quant_config=self.quant_config)
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+ self.logits_processor = LogitsProcessor(self.config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+
+@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
+ info=Qwen3VLMoeProcessingInfo,
+ dummy_inputs=Qwen3VLDummyInputsBuilder)
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super(Qwen3VLForConditionalGeneration, self).__init__()
+ config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+ self.text_dim = config.text_config.hidden_size
+
+        if is_hpu:
+            vision_transformer_cls = Qwen3_VisionTransformerStaticShape
+        else:
+            vision_transformer_cls = Qwen3_VisionTransformer
+
+        self.visual = vision_transformer_cls(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+
+ self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
+ prefix=maybe_prefix(
+ prefix,
+ "language_model"))
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ self.use_deepstack = hasattr(config.vision_config,
+ 'deepstack_visual_indexes')
+ self.deepstack_num_level = len(
+ config.vision_config.deepstack_visual_indexes
+ ) if self.use_deepstack else 0
+ # register buffer for deepstack
+ self.deepstack_input_embeds = [
+ torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
+ config.text_config.hidden_size)
+ for _ in range(self.deepstack_num_level)
+ ] if self.use_deepstack else None
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 47b191007ee1..4ebd5e1638ce 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -220,9 +220,13 @@
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
- "UltravoxModel": ("ultravox", "UltravoxModel"),
+ "Qwen3OmniMoeForConditionalGeneration": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"), # noqa: E501
+ "Qwen3OmniMoeModel": ("qwen3_omni_moe_thinker", "Qwen3OmniMoeThinkerForConditionalGeneration"), # noqa: E501
+ "Qwen3VLForConditionalGeneration": ("qwen3_vl", "Qwen3VLForConditionalGeneration"), # noqa: E501
+ "Qwen3VLMoeForConditionalGeneration": ("qwen3_vl_moe", "Qwen3VLMoeForConditionalGeneration"), # noqa: E501
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
+ "UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index b56ad407ce55..fdcd04bdd6d3 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -7,11 +7,11 @@
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
-from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
- Union, cast, final)
+from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
+ cast, final)
import numpy as np
-from typing_extensions import NotRequired, TypeAlias
+from typing_extensions import NotRequired, TypeAlias, TypeVar
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.utils import LazyLoader, full_groupby, is_list_of
@@ -601,6 +601,85 @@ def modality(self) -> str:
return next(iter(modalities))
+_I = TypeVar(
+ "_I",
+ MultiModalKwargsItem,
+ Optional[MultiModalKwargsItem],
+ default=MultiModalKwargsItem,
+)
+
+
+class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
+ """
+ A dictionary of
+ [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
+ by modality.
+ """
+
+ @staticmethod
+ def from_hf_inputs(
+ hf_inputs: "BatchFeature",
+ config_by_key: Mapping[str, MultiModalFieldConfig],
+ ):
+ # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
+ # We assume that those fields are not used in vLLM
+ elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
+ keys_by_modality = defaultdict[str, set[str]](set)
+ for key, config in config_by_key.items():
+ batch = hf_inputs.get(key)
+ if batch is not None:
+ elems = config.build_elems(key, batch)
+ if len(elems) > 0:
+ elems_by_key[key] = elems
+ keys_by_modality[config.modality].add(key)
+
+ items = list[MultiModalKwargsItem]()
+ for modality, keys in keys_by_modality.items():
+ elems_in_modality = {k: elems_by_key[k] for k in keys}
+ batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+ if len(set(batch_sizes.values())) > 1:
+ raise ValueError(
+ f"Cannot merge different batch sizes for {modality=}! "
+ f"Found: {batch_sizes=}")
+
+ batch_size = next(iter(batch_sizes.values()))
+ for item_idx in range(batch_size):
+ elems = [v[item_idx] for v in elems_in_modality.values()]
+ items.append(MultiModalKwargsItem.from_elems(elems))
+
+ return MultiModalKwargsItems.from_seq(items)
+
+ @staticmethod
+ def from_seq(items: Sequence[MultiModalKwargsItem]):
+ items_by_modality = full_groupby(items, key=lambda x: x.modality)
+ return MultiModalKwargsItems(items_by_modality)
+
+ def __getitem__(self, modality: str) -> Sequence[_I]:
+ if modality not in self:
+ raise KeyError(f"Modality {modality!r} not found. "
+ f"Available modalities: {set(self.keys())}")
+
+ return super().__getitem__(modality) # type: ignore[return-value]
+
+ def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
+ elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+ for modality, items in self.items():
+ for i, item in enumerate(items):
+ if item is None:
+ raise RuntimeError("Cannot build data from empty "
+ f"mm_items[{modality}][{i}]")
+
+ for key, elem in item.items():
+ elems_by_key[key].append(elem)
+
+ return MultiModalKwargs({
+ key:
+ elems[0].field.reduce_data(elems, pin_memory=pin_memory)
+ for key, elems in elems_by_key.items()
+ })
+
+
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]):
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index f15264cbde0b..903d3f49c6bd 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -146,9 +146,12 @@ def load_bytes(cls,
# Use transformers transformers.video_utils.VideoMetadata format
metadata = {
"total_num_frames": total_frames_num,
- "fps": original_fps,
- "duration": duration,
- "video_backend": "opencv"
+ "fps": num_frames / duration,
+ "video_backend": "opencv",
+ "frames_indices": list(range(num_frames)),
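+                # Indices of the frames kept by vLLM's own sampling; used by
+                # the processor to recover per-frame timestamps.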
+ # extra field used to control hf processor's video
+ # sampling behavior
+ "do_sample_frames": num_frames == total_frames_num,
}
return frames, metadata
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 32b0c4d711b6..4e9a38b05f94 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -723,7 +723,9 @@ def compute_input_embeddings_for_mrope_mm_optimized(
input_ids = kwargs['input_ids']
with compile_only_mode_context_false():
if self.model_is_mrope:
- if self.model.config.model_type == 'qwen2_5_omni_thinker':
+ if self.model.config.model_type in [
+ 'qwen2_5_omni_thinker', 'qwen3_omni_moe_thinker'
+ ]:
multimodal_embeddings = \
self.model.get_multimodal_embeddings_v0(**kwargs)
inputs_embeds = self.model.get_input_embeddings_v0(
@@ -2801,9 +2803,15 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
image_h = img_args // 8
image_grid_thw = torch.tensor(
[[1, image_h, int(img_args / image_h)]])
+        # Feature dim of dummy pixel_values
+        # (in_channels * temporal_patch_size * patch_size**2):
+        # 1176 for the Qwen2-VL family, 1536 for Qwen3-VL / Qwen3-Omni.
+        embed_dim = 1176
+        if any(model_type in self.get_model().config.model_type
+               for model_type in ('qwen3_vl', 'qwen3_omni')):
+            embed_dim = 1536
pixel_values = torch.randn(
image_grid_thw[0].prod(),
- 1176) # TODO: figure out the variable name
+ embed_dim) # TODO: figure out the variable name
assert pixel_values.shape[0] % 64 == 0, (
f"pixel_values must be sliced in 64 chunks, "