11 changes: 8 additions & 3 deletions tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py
@@ -29,9 +29,6 @@ def init_model_and_config(self, model: Union[nn.Module,
raise ValueError("model must have a config attribute")

self._tp_size = 1 if model.model_config.mapping.enable_attention_dp else model.model_config.mapping.tp_size
self._head_dim = model.config.head_dim if hasattr(
model.config, 'head_dim'
) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads

self.map_weights()

@@ -173,3 +170,11 @@ def model(self) -> Union[nn.Module, DecoderModelForCausalLM]:
if self._model is None:
raise RuntimeError("Weight mapper is not initialized")
return self._model

@property
def _head_dim(self) -> int:
model = self.model
head_dim = model.config.head_dim if hasattr(
model.config, 'head_dim'
) and model.config.head_dim is not None else model.config.hidden_size // model.config.num_attention_heads
return head_dim
@@ -1,3 +1,8 @@
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
Qwen3VLTextConfig,
Qwen3VLVisionConfig,
)

from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
from tensorrt_llm._torch.models.modeling_utils import register_mapper

@@ -6,3 +11,17 @@
class Qwen3VLHfWeightMapper(HfWeightMapper):
def preprocess_weights(self, weights: dict) -> dict:
return weights

@property
def _head_dim(self) -> int:
config = self.model.config
if (head_dim := getattr(config, "head_dim", None)) is not None:
return head_dim
if isinstance(config, Qwen3VLTextConfig):
num_heads = config.num_attention_heads
elif isinstance(config, Qwen3VLVisionConfig):
num_heads = config.num_heads
else:
raise TypeError(f"Unexpected config class {type(config).__name__}.")

return config.hidden_size // num_heads
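
A quick illustration of the fallback order this property implements, using hypothetical stand-in config objects rather than real Qwen3VLTextConfig / Qwen3VLVisionConfig instances: an explicit head_dim wins, otherwise the head dimension is derived from hidden_size and the head count (num_attention_heads on the text config, num_heads on the vision config).

# Minimal sketch (not part of the PR): mirrors the fallback order of the property
# above on stand-in configs; the isinstance dispatch is replaced here by a simple
# attribute fallback, and all field values are illustrative.
from types import SimpleNamespace


def resolve_head_dim(config) -> int:
    # An explicit head_dim wins; otherwise derive it from hidden_size and head count.
    if (head_dim := getattr(config, "head_dim", None)) is not None:
        return head_dim
    num_heads = getattr(config, "num_attention_heads", None) or config.num_heads
    return config.hidden_size // num_heads


text_like = SimpleNamespace(head_dim=None, hidden_size=2048, num_attention_heads=16)
vision_like = SimpleNamespace(hidden_size=1152, num_heads=16)  # no head_dim field

assert resolve_head_dim(text_like) == 128   # 2048 // 16
assert resolve_head_dim(vision_like) == 72  # 1152 // 16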
110 changes: 101 additions & 9 deletions tensorrt_llm/_torch/models/modeling_qwen3vl.py
@@ -25,6 +25,7 @@
MultimodalPlaceholderPlacement,
TextPrompt,
register_input_processor,
support_multimodal_disaggregated,
)
from ...inputs.multimodal import MultimodalParams
from ...logger import logger
@@ -350,6 +351,85 @@ def __call__(
"multimodal_data": multimodal_data,
}

def get_prompt_token_ids(
self, inputs: TextPrompt, mm_handles: List[Dict[str, Any]]
) -> Tuple[List[int], List[int], List[int]]:
"""
Build input token ids with multimodal placeholders expanded to the number of MM tokens.

Args:
inputs: Text prompt input container. Must contain a non-empty prompt string.
mm_handles: List of multimodal embedding handles. Currently only a single handle is supported.

Returns:
Tuple[List[int], List[int], List[int]]:
- expanded_ids: token ids with each image token expanded to a placeholder repeated per MM token
- mm_token_length: per-image MM token lengths
- mm_token_offsets: start offsets (positions) for each image's MM tokens within expanded_ids
"""
# TODO: Move this function to the base input processor class when extending for more models
text_prompt = inputs.get("prompt")
if not text_prompt:
raise ValueError("Text prompt is required but not provided")

if not isinstance(mm_handles, list):
raise TypeError("mm_handles must be a list")

if len(mm_handles) > 1:
# TODO: only support single multimodal item within a request for now
raise NotImplementedError("Only one mm_handle is supported for Qwen3 VL for now")

hidden_size = mm_handles[0]["tensor_size"][1]
num_deepstack_levels = len(self.config.vision_config.deepstack_visual_indexes)
# Unlike previous Qwen VL models, the embeddings are concatenated with feature maps
# from the deepstack layers, so the expected hidden size scales with (1 + num_deepstack_levels).
expected_size = self.config.text_config.hidden_size * (1 + num_deepstack_levels)
if hidden_size != expected_size:
raise RuntimeError(
f"Expected multimodal embedding to have hidden size {expected_size}, got {hidden_size}."
)

input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids[0]

# TODO: what about `video_token_id`?
image_token_index = self.config.image_token_id

image_mask = input_ids == image_token_index
image_positions = torch.where(image_mask)[0]
num_images = len(image_positions)
assert num_images == len(mm_handles), "Number of images must match number of mm_handles"
total_mm_tokens = sum(mm_handle["tensor_size"][0] for mm_handle in mm_handles)
final_length = len(input_ids) - num_images + total_mm_tokens
# Create output tensor
expanded_ids = torch.empty(final_length, dtype=input_ids.dtype)
placeholder_id = self.tllm_multimodal_token_id

# Fill the expanded sequence
write_pos = 0
image_cnt = 0
mm_token_length = []
mm_token_offsets = []
for read_pos in range(len(input_ids)):
if input_ids[read_pos] == image_token_index:
# Replace with placeholder id
mm_token_num = mm_handles[image_cnt]["tensor_size"][0]
expanded_ids[write_pos : write_pos + mm_token_num] = placeholder_id
mm_token_offsets.append(write_pos)
mm_token_length.append(mm_token_num)
write_pos += mm_token_num
image_cnt += 1
else:
# Copy text token as-is
expanded_ids[write_pos] = input_ids[read_pos]
write_pos += 1

assert write_pos == final_length, f"Write position mismatch: {write_pos} != {final_length}"
assert mm_token_length[-1] + mm_token_offsets[-1] <= final_length, (
f"mm_token_length[-1] + mm_token_offsets[-1] ({mm_token_length[-1] + mm_token_offsets[-1]}) should be less "
f"than or equal to final_length ({final_length})"
)
return expanded_ids.to(torch.int32).tolist(), mm_token_length, mm_token_offsets
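
A toy walkthrough of the expansion get_prompt_token_ids performs, using made-up token ids; in the real method the image token id comes from the model config and the placeholder id from self.tllm_multimodal_token_id.

# Illustrative sketch (not part of the PR): the placeholder expansion performed by
# get_prompt_token_ids, with made-up token ids. Assume the image token id is 151655,
# the placeholder id is 0, and a single image handle contributing 4 MM tokens.
import torch

input_ids = torch.tensor([10, 151655, 20, 30])  # text, image, text, text
image_token_index, placeholder_id, mm_token_num = 151655, 0, 4

expanded, mm_token_offsets, mm_token_length = [], [], []
for tok in input_ids.tolist():
    if tok == image_token_index:
        mm_token_offsets.append(len(expanded))  # start offset of this image's MM tokens
        mm_token_length.append(mm_token_num)
        expanded.extend([placeholder_id] * mm_token_num)
    else:
        expanded.append(tok)

assert expanded == [10, 0, 0, 0, 0, 20, 30]
assert mm_token_offsets == [1] and mm_token_length == [4]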


class Qwen3VLVisionAttention(Qwen2_5_VLVisionAttention):
def __init__(self, model_config, layer_idx):
@@ -825,6 +905,7 @@ def __init__(
llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
else:
raise ValueError(f"Unsupported architecture: {self.original_arch}")
# Builds the text decoder, e.g. Qwen3ForCausalLM (or Qwen3MoeForCausalLM for MoE checkpoints).
self.llm = AutoModelForCausalLM.from_config(llm_model_config)

if not _is_disagg():
@@ -953,22 +1034,16 @@ def forward(

# NOTE: Qwen*-VL series has mrope_config even on the text-only prompts,
# so we need to separate the mm_multimodal_params from the text-only prompts.
mm_multimodal_params = [
multimodal_param
for multimodal_param in multimodal_params
if multimodal_param.multimodal_data.get("image", {}).get("pixel_values") is not None
or multimodal_param.multimodal_data.get("video", {}).get("pixel_values_videos")
is not None
]
mm_multimodal_params = self._get_requests_with_mm_data(multimodal_params)
if len(mm_multimodal_params) > 0:
if not _is_disagg():
mm_embeds = get_multimodal_embeddings(
encoder_forward_fn=self.mm_encoder.forward,
multimodal_params=mm_multimodal_params,
)
else:
elif not getattr(self, "support_mm_disagg", False):
raise NotImplementedError(
"Qwen3VLModel does not support disaggregated inference yet. Please unset "
f"{type(self)} does not support disaggregated inference yet. Please unset "
"the TLLM_MULTIMODAL_DISAGGREGATED environment variable, or set it to '0'."
)
mm_embeds = find_input_mm_embeds(mm_embeds, mm_multimodal_params)
@@ -1008,7 +1083,24 @@ def forward(
logger.debug(f"output shape: {output_prob.shape}")
return output_prob

def _get_requests_with_mm_data(self, multimodal_params):
mm_multimodal_params = []
for multimodal_param in multimodal_params:
data = multimodal_param.multimodal_data
if (
# The first 2 conditions check whether there is input on which inference should be run.
data.get("image", {}).get("pixel_values") is not None
or data.get("video", {}).get("pixel_values_videos") is not None
# This condition handles requests whose embeddings are already populated, e.g. in
# the prefill worker under EPD disaggregation.
or data.get("multimodal_embedding")
):
mm_multimodal_params.append(multimodal_param)

return mm_multimodal_params
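
A minimal sketch of the filtering behaviour on hypothetical request objects: a request is kept if it carries raw pixel inputs or an already-populated multimodal embedding (the EPD-disaggregation prefill case mentioned above).

# Minimal sketch (not part of the PR): the filter above applied to stand-in
# request objects built with SimpleNamespace.
from types import SimpleNamespace

with_pixels = SimpleNamespace(multimodal_data={"image": {"pixel_values": object()}})
with_embedding = SimpleNamespace(multimodal_data={"multimodal_embedding": [object()]})
text_only = SimpleNamespace(multimodal_data={})


def has_mm_data(param) -> bool:
    data = param.multimodal_data
    return (
        data.get("image", {}).get("pixel_values") is not None
        or data.get("video", {}).get("pixel_values_videos") is not None
        or bool(data.get("multimodal_embedding"))
    )


kept = [p for p in (with_pixels, with_embedding, text_only) if has_mm_data(p)]
assert kept == [with_pixels, with_embedding]  # text-only request is filtered out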


@support_multimodal_disaggregated
@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
@register_auto_model("Qwen3VLForConditionalGeneration")
@register_input_processor(
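The forward path above only raises when getattr(self, "support_mm_disagg", False) is false, and the class here is wrapped in support_multimodal_disaggregated, so a plausible reading is that the decorator sets that flag on the model class. The snippet below is a hypothetical sketch of such a decorator, not the actual implementation in tensorrt_llm.

# Hypothetical sketch only: one way a class decorator could set the flag that
# forward() reads via getattr(self, "support_mm_disagg", False). The real
# support_multimodal_disaggregated decorator in tensorrt_llm may work differently.
from typing import Type, TypeVar

T = TypeVar("T")


def support_multimodal_disaggregated_sketch(cls: Type[T]) -> Type[T]:
    # Mark the model class as able to consume precomputed multimodal embeddings
    # produced by a separate encoder worker (EPD disaggregation).
    cls.support_mm_disagg = True
    return cls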
@@ -21,10 +21,12 @@

_LLAVA_DIR = llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf"
_QWEN_2_5_VL_DIR = llm_models_root() / "Qwen2.5-VL-3B-Instruct"
_QWEN_3_VL_DIR = llm_models_root() / "Qwen3" / "Qwen3-VL-2B-Instruct"


# TODO: Add multi-image in single chat test
@pytest.mark.parametrize("model_dir", [_LLAVA_DIR, _QWEN_2_5_VL_DIR])
@pytest.mark.parametrize("model_dir",
[_LLAVA_DIR, _QWEN_2_5_VL_DIR, _QWEN_3_VL_DIR])
@pytest.mark.parametrize("pd_disagg", [False, True])
def test_single_image_chat(model_dir, pd_disagg):
"""Test processing single image using encoder (pass mm_embeddings) + LLM API.
@@ -180,6 +182,7 @@ def test_single_image_chat(model_dir, pd_disagg):
# Qwen2.5 VL's vision encoder seems to output different embeddings based on this value.
# The test only passes with this set to 1.
(_QWEN_2_5_VL_DIR, 1),
(_QWEN_3_VL_DIR, 3),
],
)
def test_multi_request_batch_chat(model_dir, encoder_max_batch_size):