From 76db627bee3aff66936a6bf95385eb7794626fbb Mon Sep 17 00:00:00 2001
From: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Date: Tue, 6 Jan 2026 15:00:44 -0800
Subject: [PATCH] [None][feat] EPD for Qwen3 VL

* Why?

We would like to support EPD disaggregated serving for Qwen3 VL.

* What?

This commit adds such support and extends existing unit tests for correctness
checks. Some minor (protected) interface changes had to be made to the weight
mapper as a side effect.

Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
---
 .../models/checkpoints/base_weight_mapper.py  |  11 +-
 .../checkpoints/hf/qwen3vl_weight_mapper.py   |  19 +++
 .../_torch/models/modeling_qwen3vl.py         | 110 ++++++++++++++++--
 .../multimodal/test_mm_encoder_standalone.py  |   5 +-
 4 files changed, 132 insertions(+), 13 deletions(-)

diff --git a/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py
index 4d78b3dcb19..790be65eed5 100644
--- a/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py
+++ b/tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py
@@ -29,9 +29,6 @@ def init_model_and_config(self, model: Union[nn.Module,
             raise ValueError("model must have a config attribute")

         self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-        self._head_dim = model.config.head_dim if hasattr(
-            model.config, 'head_dim'
-        ) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads

         self.map_weights()

@@ -173,3 +170,11 @@ def model(self) -> Union[nn.Module, DecoderModelForCausalLM]:
         if self._model is None:
             raise RuntimeError("Weight mapper is not initialized")
         return self._model
+
+    @property
+    def _head_dim(self) -> int:
+        model = self.model
+        head_dim = model.config.head_dim if hasattr(
+            model.config, 'head_dim'
+        ) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads
+        return head_dim
diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py
index 41b3da875ea..24a3602db94 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py
@@ -1,3 +1,8 @@
+from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+    Qwen3VLTextConfig,
+    Qwen3VLVisionConfig,
+)
+
 from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import register_mapper

@@ -6,3 +11,17 @@ class Qwen3VLHfWeightMapper(HfWeightMapper):

     def preprocess_weights(self, weights: dict) -> dict:
         return weights
+
+    @property
+    def _head_dim(self) -> int:
+        config = self.model.config
+        if (head_dim := getattr(config, "head_dim", None)) is not None:
+            return head_dim
+        if isinstance(config, Qwen3VLTextConfig):
+            num_heads = config.num_attention_heads
+        elif isinstance(config, Qwen3VLVisionConfig):
+            num_heads = config.num_heads
+        else:
+            raise TypeError(f"Unexpected config class {type(config).__name__}.")
+
+        return config.hidden_size // num_heads
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3vl.py b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
index f89d801f94d..d073f6745b7 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
@@ -25,6 +25,7 @@
     MultimodalPlaceholderPlacement,
     TextPrompt,
     register_input_processor,
+    support_multimodal_disaggregated,
 )
 from ...inputs.multimodal import MultimodalParams
 from ...logger import logger
@@ -350,6 +351,85 @@ def __call__(
             "multimodal_data": multimodal_data,
         }

+    def get_prompt_token_ids(
+        self, inputs: TextPrompt, mm_handles: List[Dict[str, Any]]
+    ) -> Tuple[List[int], List[int], List[int]]:
+        """
+        Build input token ids with multimodal placeholders expanded to the number of MM tokens.
+
+        Args:
+            inputs: Text prompt input container. Must contain a non-empty prompt string.
+            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+
+        Returns:
+            Tuple[List[int], List[int], List[int]]:
+                - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
+                - mm_token_length: per-image MM token lengths
+                - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
+        """
+        # TODO: Move this function to the base input processor class when extending it to more models.
+        text_prompt = inputs.get("prompt")
+        if not text_prompt:
+            raise ValueError("Text prompt is required but not provided")
+
+        if not isinstance(mm_handles, list):
+            raise TypeError("mm_handles must be a list")
+
+        if len(mm_handles) > 1:
+            # TODO: only a single multimodal item per request is supported for now.
+            raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")
+
+        hidden_size = mm_handles[0]["tensor_size"][1]
+        num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
+        # Unlike previous Qwen VL models, the embeddings are concatenated with the feature
+        # maps from the deepstack layers, hence the scaling factor on the expected size.
+        expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
+        if hidden_size != expected_size:
+            raise RuntimeError(
+                f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
+            )
+
+        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]
+
+        # TODO: what about `video_token_id`?
+        image_token_index = self.config.image_token_id
+
+        image_mask = input_ids == image_token_index
+        image_positions = torch.where(image_mask)[0]
+        num_images = len(image_positions)
+        assert num_images == len(mm_handles), "Number of images must match number of mm_handles"
+        total_mm_tokens = sum(mm_handle["tensor_size"][0] for mm_handle in mm_handles)
+        final_length = len(input_ids) - num_images + total_mm_tokens
+        # Create output tensor
+        expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
+        placeholder_id = self.tllm_multimodal_token_id
+
+        # Fill the expanded sequence
+        write_pos = 0
+        image_cnt = 0
+        mm_token_length = []
+        mm_token_offsets = []
+        for read_pos in range(len(input_ids)):
+            if input_ids[read_pos] == image_token_index:
+                # Replace with placeholder id
+                mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
+                expanded_ids[write_pos : write_pos + mm_token_num] = placeholder_id
+                mm_token_offsets.append(write_pos)
+                mm_token_length.append(mm_token_num)
+                write_pos += mm_token_num
+                image_cnt += 1
+            else:
+                # Copy text token as-is
+                expanded_ids[write_pos] = input_ids[read_pos]
+                write_pos += 1
+
+        assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
+        assert mm_token_length[-1] + mm_token_offsets[-1] <= final_length, (
+            f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less "
+            f"than or equal to final_length ({final_length})"
+        )
+        return expanded_ids.to(torch.int32).tolist(), mm_token_length, mm_token_offsets
+

 class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
     def __init__(self, model_config, layer_idx):
@@ -825,6 +905,7 @@ def __init__(
             llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
         else:
             raise ValueError(f"Unsupported architecture: {self.original_arch}")
+        # Qwen3ForCausalLM.
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)

         if not _is_disagg():
@@ -953,22 +1034,16 @@ def forward(

         # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts,
         # so we need to separate the mm_multimodal_params from the text-only prompts.
-        mm_multimodal_params = [
-            multimodal_param
-            for multimodal_param in multimodal_params
-            if multimodal_param.multimodal_data.get("image", {}).get("pixel_values") is not None
-            or multimodal_param.multimodal_data.get("video", {}).get("pixel_values_videos")
-            is not None
-        ]
+        mm_multimodal_params = self._get_requests_with_mm_data(multimodal_params)
         if len(mm_multimodal_params) > 0:
             if not _is_disagg():
                 mm_embeds = get_multimodal_embeddings(
                     encoder_forward_fn=self.mm_encoder.forward,
                     multimodal_params=mm_multimodal_params,
                 )
-            else:
+            elif not getattr(self, "support_mm_disagg", False):
                 raise NotImplementedError(
-                    "Qwen3VLModel does not support disaggregated inference yet. Please unset "
+                    f"{type(self)} does not support disaggregated inference yet. Please unset "
                     "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
                 )
             mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params)
@@ -1008,7 +1083,24 @@ def forward(
         logger.debug(f"output shape: {output_prob.shape}")
         return output_prob

+    def _get_requests_with_mm_data(self, multimodal_params):
+        mm_multimodal_params = []
+        for multimodal_param in multimodal_params:
+            data = multimodal_param.multimodal_data
+            if (
+                # The first 2 conditions check whether there is input on which inference should be run.
+                data.get("image", {}).get("pixel_values") is not None
+                or data.get("video", {}).get("pixel_values_videos") is not None
+                # This condition covers the case where the embeddings are already populated,
+                # e.g. on the prefill worker in EPD disaggregated serving.
+                or data.get("multimodal_embedding")
+            ):
+                mm_multimodal_params.append(multimodal_param)
+
+        return mm_multimodal_params
+

+@support_multimodal_disaggregated
 @register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
 @register_auto_model("Qwen3VLForConditionalGeneration")
 @register_input_processor(
diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
index 99154dd074a..993559879be 100644
--- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
+++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py
@@ -21,10 +21,12 @@

 _LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
 _QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
+_QWEN_3_VL_DIR = llm_models_root() / "Qwen3" / "Qwen3-VL-2B-Instruct"


 # TODO: Add multi-image in single chat test
-@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
+@pytest.mark.parametrize("model_dir",
+                         [_LLAVA_DIR, _QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR])
 @pytest.mark.parametrize("pd_disagg", [False, True])
 def test_single_image_chat(model_dir, pd_disagg):
     """Test processing single image using encoder (pass mm_embeddings) + LLM API.
@@ -180,6 +182,7 @@ def test_single_image_chat(model_dir, pd_disagg):
         # Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
         # The test only passes with this set to 1.
         (_QWEN_2_5_VL_DIR, 1),
+        (_QWEN_3_VL_DIR, 3),
     ],
 )
 def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
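For reviewers: the expansion contract implemented by get_prompt_token_ids above can be
summarized with the standalone sketch below. It is an illustration only, not part of the
patch; the image token id, placeholder id, and handle dicts are hypothetical stand-ins for
self.config.image_token_id, self.tllm_multimodal_token_id, and the mm_handles entries.

IMAGE_TOKEN_ID = 151655   # hypothetical image token id, stands in for config.image_token_id
PLACEHOLDER_ID = -1       # hypothetical stand-in for tllm_multimodal_token_id


def expand_placeholders(input_ids, mm_handles,
                        image_token_id=IMAGE_TOKEN_ID,
                        placeholder_id=PLACEHOLDER_ID):
    """Expand each image token into tensor_size[0] placeholder tokens."""
    expanded, lengths, offsets = [], [], []
    image_cnt = 0
    for tok in input_ids:
        if tok == image_token_id:
            num_mm_tokens = mm_handles[image_cnt]["tensor_size"][0]
            offsets.append(len(expanded))      # where this image's MM tokens start
            lengths.append(num_mm_tokens)      # how many MM tokens it occupies
            expanded.extend([placeholder_id] * num_mm_tokens)
            image_cnt += 1
        else:
            expanded.append(tok)               # text tokens are copied as-is
    return expanded, lengths, offsets


# One image whose embedding handle reports 4 MM tokens of hidden size 2048:
ids, lengths, offsets = expand_placeholders(
    [101, IMAGE_TOKEN_ID, 102], [{"tensor_size": [4, 2048]}])
assert ids == [101, -1, -1, -1, -1, 102]
assert lengths == [4] and offsets == [1]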