
Commit 76db627

[None][feat] EPD for Qwen3 VL
* Why? We would like to support EPD disaggregated serving for Qwen3 VL.
* What? This commit adds that support and extends existing unit tests for correctness checks. Some minor (protected) interface changes had to be made to the weight mapper as a side effect.

Signed-off-by: William Zhang <[email protected]>
1 parent 6b71b03 commit 76db627

File tree: 4 files changed, +132 / -13 lines

tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py

Lines changed: 8 additions & 3 deletions
@@ -29,9 +29,6 @@ def init_model_and_config(self, model: Union[nn.Module,
             raise ValueError("model must have a config attribute")
 
         self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
-        self._head_dim = model.config.head_dim if hasattr(
-            model.config, 'head_dim'
-        ) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads
 
         self.map_weights()
 
@@ -173,3 +170,11 @@ def model(self) -> Union[nn.Module, DecoderModelForCausalLM]:
         if self._model is None:
             raise RuntimeError("Weight mapper is not initialized")
         return self._model
+
+    @property
+    def _head_dim(self) -> int:
+        model = self.model
+        head_dim = model.config.head_dim if hasattr(
+            model.config, 'head_dim'
+        ) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads
+        return head_dim
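
Moving `_head_dim` out of `init_model_and_config` and into a read-only property defers the computation until the mapper's model (and hence its config) is actually needed, and it lets subclasses override the derivation for configs that spell their head count differently. A minimal sketch of the pattern, using a stand-in class rather than the real TensorRT-LLM mapper:

from types import SimpleNamespace


class MapperSketch:
    """Stand-in for the base weight mapper, for illustration only."""

    def __init__(self, config):
        self.config = config

    @property
    def _head_dim(self) -> int:
        # Prefer an explicit head_dim; otherwise derive it from the hidden size.
        head_dim = getattr(self.config, "head_dim", None)
        if head_dim is not None:
            return head_dim
        return self.config.hidden_size // self.config.num_attention_heads


# Hypothetical config values: 4096 / 32 == 128.
cfg = SimpleNamespace(head_dim=None, hidden_size=4096, num_attention_heads=32)
assert MapperSketch(cfg)._head_dim == 128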

tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py

Lines changed: 19 additions & 0 deletions
@@ -1,3 +1,8 @@
+from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+    Qwen3VLTextConfig,
+    Qwen3VLVisionConfig,
+)
+
 from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
 from tensorrt_llm._torch.models.modeling_utils import register_mapper
 
@@ -6,3 +11,17 @@
 class Qwen3VLHfWeightMapper(HfWeightMapper):
     def preprocess_weights(self, weights: dict) -> dict:
         return weights
+
+    @property
+    def _head_dim(self) -> int:
+        config = self.model.config
+        if (head_dim := getattr(config, "head_dim", None)) is not None:
+            return head_dim
+        if isinstance(config, Qwen3VLTextConfig):
+            num_heads = config.num_attention_heads
+        elif isinstance(config, Qwen3VLVisionConfig):
+            num_heads = config.num_heads
+        else:
+            raise TypeError(f"Unexpected config class {type(config).__name__}.")
+
+        return config.hidden_size // num_heads
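
The resolution order in this override is: an explicit `head_dim` on the config wins; otherwise the divisor depends on whether the mapper sees the text config (`num_attention_heads`) or the vision config (`num_heads`). A small self-contained sketch of that logic with hypothetical sizes; an explicit flag stands in for the `isinstance()` checks against the Hugging Face config classes:

from types import SimpleNamespace


def resolve_head_dim(config, is_vision: bool) -> int:
    # Mirrors the property above, with a flag instead of isinstance() dispatch.
    if (head_dim := getattr(config, "head_dim", None)) is not None:
        return head_dim
    num_heads = config.num_heads if is_vision else config.num_attention_heads
    return config.hidden_size // num_heads


# Hypothetical sizes, purely illustrative.
text_cfg = SimpleNamespace(head_dim=128, hidden_size=2048, num_attention_heads=16)
vision_cfg = SimpleNamespace(hidden_size=1152, num_heads=16)  # no head_dim attribute
assert resolve_head_dim(text_cfg, is_vision=False) == 128  # explicit head_dim wins
assert resolve_head_dim(vision_cfg, is_vision=True) == 72   # 1152 // 16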

tensorrt_llm/_torch/models/modeling_qwen3vl.py

Lines changed: 101 additions & 9 deletions
@@ -25,6 +25,7 @@
     MultimodalPlaceholderPlacement,
     TextPrompt,
     register_input_processor,
+    support_multimodal_disaggregated,
 )
 from ...inputs.multimodal import MultimodalParams
 from ...logger import logger
@@ -350,6 +351,85 @@ def __call__(
             "multimodal_data": multimodal_data,
         }
 
+    def get_prompt_token_ids(
+        self, inputs: TextPrompt, mm_handles: List[Dict[str, Any]]
+    ) -> Tuple[List[int], List[int], List[int]]:
+        """
+        Build input token ids with multimodal placeholders expanded to the number of MM tokens.
+
+        Args:
+            inputs: Text prompt input container. Must contain a non-empty prompt string.
+            mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.
+
+        Returns:
+            Tuple[List[int], List[int], List[int]]:
+                - expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
+                - mm_token_length: per-image MM token lengths
+                - mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
+        """
+        # TODO: Move this function to the base input processor class when extending for more models
+        text_prompt = inputs.get("prompt")
+        if not text_prompt:
+            raise ValueError("Text prompt is required but not provided")
+
+        if not isinstance(mm_handles, list):
+            raise TypeError("mm_handles must be a list")
+
+        if len(mm_handles) > 1:
+            # TODO: only support single multimodal item within a request for now
+            raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")
+
+        hidden_size = mm_handles[0]["tensor_size"][1]
+        num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
+        # This is because, unlike previous Qwen VL models, the embeddings are concatenated with
+        # feature maps from deepstack layers.
+        expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
+        if hidden_size != expected_size:
+            raise RuntimeError(
+                f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
+            )
+
+        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]
+
+        # TODO: what about `video_token_id`?
+        image_token_index = self.config.image_token_id
+
+        image_mask = input_ids == image_token_index
+        image_positions = torch.where(image_mask)[0]
+        num_images = len(image_positions)
+        assert num_images == len(mm_handles), "Number of images must match number of mm_handles"
+        total_mm_tokens = sum(mm_handle["tensor_size"][0] for mm_handle in mm_handles)
+        final_length = len(input_ids) - num_images + total_mm_tokens
+        # Create output tensor
+        expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
+        placeholder_id = self.tllm_multimodal_token_id
+
+        # Fill the expanded sequence
+        write_pos = 0
+        image_cnt = 0
+        mm_token_length = []
+        mm_token_offsets = []
+        for read_pos in range(len(input_ids)):
+            if input_ids[read_pos] == image_token_index:
+                # Replace with placeholder id
+                mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
+                expanded_ids[write_pos : write_pos + mm_token_num] = placeholder_id
+                mm_token_offsets.append(write_pos)
+                mm_token_length.append(mm_token_num)
+                write_pos += mm_token_num
+                image_cnt += 1
+            else:
+                # Copy text token as-is
+                expanded_ids[write_pos] = input_ids[read_pos]
+                write_pos += 1
+
+        assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
+        assert mm_token_length[-1] + mm_token_offsets[-1] <= final_length, (
+            f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less "
+            f"than or equal to final_length ({final_length})"
+        )
+        return expanded_ids.to(torch.int32).tolist(), mm_token_length, mm_token_offsets
+
 
 class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
     def __init__(self, model_config, layer_idx):
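
With EPD disaggregation the vision encoder runs in a separate worker, so `get_prompt_token_ids` above lets the text side build its input sequence from an embedding handle: each image token in the tokenized prompt is replaced by `tensor_size[0]` copies of the TRT-LLM placeholder id, and the per-image lengths and offsets are returned alongside. (The method also validates that the handle's hidden size equals the text `hidden_size` times `1 + len(deepstack_visual_indexes)`, since Qwen3 VL embeddings concatenate the deepstack feature maps.) A self-contained sketch of the expansion on toy data, with hypothetical token ids and sizes and plain lists instead of torch tensors:

# Toy prompt: two text tokens, one image token, one trailing text token.
IMAGE_TOKEN_ID = 151655   # hypothetical image token id
PLACEHOLDER_ID = -1       # hypothetical tllm_multimodal_token_id
input_ids = [9906, 1917, IMAGE_TOKEN_ID, 13]
mm_handles = [{"tensor_size": [4, 8192]}]  # 4 MM tokens, hidden size 8192 (made up)

expanded_ids, mm_token_length, mm_token_offsets = [], [], []
for tok in input_ids:
    if tok == IMAGE_TOKEN_ID:
        # Replace the single image token with one placeholder per MM token.
        num_mm_tokens = mm_handles[len(mm_token_length)]["tensor_size"][0]
        mm_token_offsets.append(len(expanded_ids))
        mm_token_length.append(num_mm_tokens)
        expanded_ids.extend([PLACEHOLDER_ID] * num_mm_tokens)
    else:
        expanded_ids.append(tok)

assert expanded_ids == [9906, 1917, -1, -1, -1, -1, 13]
assert mm_token_length == [4] and mm_token_offsets == [2]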
@@ -825,6 +905,7 @@ def __init__(
             llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
         else:
             raise ValueError(f"Unsupported architecture: {self.original_arch}")
+        # Qwen3ForCausalLM.
        self.llm = AutoModelForCausalLM.from_config(llm_model_config)
 
         if not _is_disagg():
@@ -953,22 +1034,16 @@ def forward(
 
         # NOTE: Qwen*-VL series has mrope_config even on the text-only prompts,
         # so we need to separate the mm_multimodal_params from the text-only prompts.
-        mm_multimodal_params = [
-            multimodal_param
-            for multimodal_param in multimodal_params
-            if multimodal_param.multimodal_data.get("image", {}).get("pixel_values") is not None
-            or multimodal_param.multimodal_data.get("video", {}).get("pixel_values_videos")
-            is not None
-        ]
+        mm_multimodal_params = self._get_requests_with_mm_data(multimodal_params)
         if len(mm_multimodal_params) > 0:
             if not _is_disagg():
                 mm_embeds = get_multimodal_embeddings(
                     encoder_forward_fn=self.mm_encoder.forward,
                     multimodal_params=mm_multimodal_params,
                 )
-            else:
+            elif not getattr(self, "support_mm_disagg", False):
                 raise NotImplementedError(
-                    "Qwen3VLModel does not support disaggregated inference yet. Please unset "
+                    f"{type(self)} does not support disaggregated inference yet. Please unset "
                     "the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
                 )
             mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params)
@@ -1008,7 +1083,24 @@ def forward(
         logger.debug(f"output shape: {output_prob.shape}")
         return output_prob
 
+    def _get_requests_with_mm_data(self, multimodal_params):
+        mm_multimodal_params = []
+        for multimodal_param in multimodal_params:
+            data = multimodal_param.multimodal_data
+            if (
+                # The first 2 conditions check whether there is input on which inference should be run.
+                data.get("image", {}).get("pixel_values") is not None
+                or data.get("video", {}).get("pixel_values_videos") is not None
+                # This condition corresponds to when the embeddings are already populated, as is e.g.
+                # the case in EPD disagg in the prefill worker.
+                or data.get("multimodal_embedding")
+            ):
+                mm_multimodal_params.append(multimodal_param)
+
+        return mm_multimodal_params
+
 
+@support_multimodal_disaggregated
 @register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
 @register_auto_model("Qwen3VLForConditionalGeneration")
 @register_input_processor(
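
The `@support_multimodal_disaggregated` decorator added above and the new `getattr(self, "support_mm_disagg", False)` check in `forward` work as a pair: decorated model classes take the disaggregated multimodal path, while undecorated ones keep raising `NotImplementedError`. A plausible minimal sketch of such a decorator, assuming it merely flags the class (the actual implementation in `tensorrt_llm` may do more):

def support_multimodal_disaggregated_sketch(cls):
    """Hypothetical stand-in for the real decorator: flag the class as supporting
    multimodal (EPD) disaggregated inference."""
    cls.support_mm_disagg = True
    return cls


@support_multimodal_disaggregated_sketch
class ToyVLM:
    pass


assert getattr(ToyVLM(), "support_mm_disagg", False) is True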

tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py

Lines changed: 4 additions & 1 deletion
@@ -21,10 +21,12 @@
 
 _LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
 _QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
+_QWEN_3_VL_DIR = llm_models_root() / "Qwen3" / "Qwen3-VL-2B-Instruct"
 
 
 # TODO: Add multi-image in single chat test
-@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
+@pytest.mark.parametrize("model_dir",
+                         [_LLAVA_DIR, _QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR])
 @pytest.mark.parametrize("pd_disagg", [False, True])
 def test_single_image_chat(model_dir, pd_disagg):
     """Test processing single image using encoder (pass mm_embeddings) + LLM API.
@@ -180,6 +182,7 @@ def test_single_image_chat(model_dir, pd_disagg):
         # Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
         # The test only passes with this set to 1.
         (_QWEN_2_5_VL_DIR, 1),
+        (_QWEN_3_VL_DIR, 3),
     ],
 )
 def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):
