Commit 73870ae

[None][feat] support Qwen3-VL dense model in pytorch backend (#9060)
Signed-off-by: Nekofish-L <[email protected]>
Parent: 827d12c
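For context, a minimal usage sketch of what this commit enables: loading a dense Qwen3-VL checkpoint through the PyTorch-backend LLM API. This is an assumption, not part of the commit: the model id is a placeholder, and the exact multimodal-input format depends on the installed TensorRT-LLM version, so only a plain text prompt is shown here (image and video inputs go through the registered Qwen3VLInputProcessorBase).

    # Hypothetical sketch: run a dense Qwen3-VL checkpoint via the LLM API,
    # assuming a TensorRT-LLM build where the LLM API uses the PyTorch backend.
    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen3-VL-8B-Instruct")  # placeholder model id
    params = SamplingParams(max_tokens=64)

    # Text-only prompt; multimodal inputs would be preprocessed by the
    # input processor registered for model_type "qwen3_vl".
    outputs = llm.generate(["Give a one-sentence summary of Qwen3-VL."], params)
    print(outputs[0].outputs[0].text)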

File tree

6 files changed, +360 −23 lines changed


tensorrt_llm/_torch/models/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
 from .modeling_qwen3 import Qwen3ForCausalLM
 from .modeling_qwen3_moe import Qwen3MoeForCausalLM
 from .modeling_qwen3_next import Qwen3NextForCausalLM
+from .modeling_qwen3vl import Qwen3VLModel
 from .modeling_qwen3vl_moe import Qwen3MoeVLModel
 from .modeling_qwen_moe import Qwen2MoeForCausalLM
 from .modeling_seedoss import SeedOssForCausalLM
@@ -76,6 +77,7 @@
     "GptOssForCausalLM",
     "SeedOssForCausalLM",
     "Glm4MoeForCausalLM",
+    "Qwen3VLModel",
 ]

 if transformers.__version__ >= "4.45.1":

tensorrt_llm/_torch/models/checkpoints/__init__.py

Lines changed: 9 additions & 19 deletions

@@ -10,6 +10,7 @@
 from .hf.qwen2vl_weight_mapper import Qwen2VLHfWeightMapper
 from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
 from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
+from .hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
 from .hf.weight_loader import HfWeightLoader
 from .hf.weight_mapper import HfWeightMapper
 from .mistral.checkpoint_loader import (MistralCheckpointLoader,
@@ -19,23 +20,12 @@
                                         MistralWeightMapper)

 __all__ = [
-    "HfConfigLoader",
-    "HfWeightLoader",
-    "HfWeightMapper",
-    "MistralConfigLoader",
-    "MistralWeightMapper",
-    "MistralCheckpointLoader",
-    "BaseCheckpointLoader",
-    "HfCheckpointLoader",
-    "NemotronHHfWeightMapper",
-    "Gemma3HfWeightMapper",
-    "MixtralHfWeightMapper",
-    "Llama4HfWeightMapper",
-    "Qwen2MoeHfWeightMapper",
-    "Qwen3MoeHfWeightMapper",
-    "Qwen2VLHfWeightMapper",
-    "Qwen3NextHfWeightMapper",
-    "LlavaNextHfWeightMapper",
-    "MistralLarge3CheckpointLoader",
-    "MistralLarge3WeightMapper",
+    "HfConfigLoader", "HfWeightLoader", "HfWeightMapper", "MistralConfigLoader",
+    "MistralWeightMapper", "MistralCheckpointLoader", "BaseCheckpointLoader",
+    "HfCheckpointLoader", "NemotronHHfWeightMapper", "Gemma3HfWeightMapper",
+    "MixtralHfWeightMapper", "Llama4HfWeightMapper", "Qwen2MoeHfWeightMapper",
+    "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
+    "Qwen3NextHfWeightMapper", "LlavaNextHfWeightMapper",
+    "MistralLarge3CheckpointLoader", "MistralLarge3WeightMapper",
+    "Qwen3VLHfWeightMapper"
 ]
tensorrt_llm/_torch/models/checkpoints/hf/qwen3vl_weight_mapper.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper
+from tensorrt_llm._torch.models.modeling_utils import register_mapper
+
+
+@register_mapper("HF", "Qwen3VLForConditionalGeneration")
+class Qwen3VLHfWeightMapper(HfWeightMapper):
+    def preprocess_weights(self, weights: dict) -> dict:
+        return weights

tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 11 additions & 2 deletions

@@ -121,6 +121,8 @@ def forward(
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
         spec_metadata: Optional[SpecMetadata] = None,
+        mrope_config: Optional[dict] = None,
+        deepstack_embeds: Optional[list[torch.Tensor]] = None,
         **kwargs,
     ) -> torch.Tensor:
         if residual is None:
@@ -137,6 +139,7 @@ def forward(
             attn_metadata=attn_metadata,
             all_reduce_params=AllReduceParams(
                 enable_allreduce=not self.disable_allreduce),
+            mrope_config=mrope_config,
             **kwargs,
         )

@@ -150,6 +153,9 @@ def forward(
                 enable_allreduce=not self.disable_allreduce),
             cutlass_min_latency_mode=False,
         )
+        if deepstack_embeds is not None and self.layer_idx in range(
+                len(deepstack_embeds)):
+            residual = residual + deepstack_embeds[self.layer_idx]

         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
@@ -191,6 +197,9 @@ def forward(
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         spec_metadata: Optional[SpecMetadata] = None,
+        mrope_config: Optional[dict] = None,
+        # args for deepstack
+        deepstack_embeds: Optional[list[torch.Tensor]] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -211,8 +220,8 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 spec_metadata=spec_metadata,
-            )
-
+                mrope_config=mrope_config,
+                deepstack_embeds=deepstack_embeds)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
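To make the new deepstack_embeds path above easier to follow, here is a small, self-contained sketch of the idea: multi-level vision features are added to the residual stream of the first few decoder layers, one tensor per layer index. The tensor shapes and layer count below are illustrative assumptions, not values taken from the model.

    import torch

    # Illustrative shapes only: 3 "deepstack" vision feature tensors, one per early layer.
    hidden_size, num_tokens = 4096, 16
    deepstack_embeds = [torch.zeros(num_tokens, hidden_size) for _ in range(3)]

    residual = torch.zeros(num_tokens, hidden_size)
    for layer_idx in range(8):  # pretend the model has 8 decoder layers
        # ... attention / MLP for this layer would run here ...
        if deepstack_embeds is not None and layer_idx in range(len(deepstack_embeds)):
            # Only the first len(deepstack_embeds) layers receive the extra vision features.
            residual = residual + deepstack_embeds[layer_idx]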

tensorrt_llm/_torch/models/modeling_qwen3vl.py

Lines changed: 58 additions & 2 deletions

@@ -21,7 +21,10 @@
     BaseMultimodalDummyInputsBuilder,
     BaseMultimodalInputProcessor,
     ExtraProcessedInputs,
+    MultimodalPlaceholderMetadata,
+    MultimodalPlaceholderPlacement,
     TextPrompt,
+    register_input_processor,
 )
 from ...inputs.multimodal import MultimodalParams
 from ...logger import logger
@@ -33,14 +36,23 @@
 from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mlp import MLP
 from ..modules.rotary_embedding import MRotaryEmbedding
+from .checkpoints.base_weight_mapper import BaseWeightMapper
+from .checkpoints.hf.qwen3vl_weight_mapper import Qwen3VLHfWeightMapper
 from .modeling_auto import AutoModelForCausalLM
 from .modeling_multimodal_utils import (
     find_input_mm_embeds,
     fuse_input_embeds,
     get_multimodal_embeddings,
 )
 from .modeling_qwen2vl import Qwen2_5_VLVisionAttention
-from .modeling_utils import ModelConfig, QuantConfig, _load_weights_impl, filter_weights
+from .modeling_utils import (
+    ModelConfig,
+    QuantConfig,
+    _load_weights_impl,
+    filter_weights,
+    register_auto_model,
+    register_vision_encoder,
+)


 class Qwen3VLInputProcessorBase(BaseMultimodalInputProcessor, BaseMultimodalDummyInputsBuilder):
@@ -807,7 +819,12 @@ def __init__(

         llm_model_config = copy.deepcopy(model_config)
         llm_model_config.pretrained_config = config.text_config
-        llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
+        if self.original_arch == "Qwen3VLForConditionalGeneration":
+            llm_model_config.pretrained_config.architectures = ["Qwen3ForCausalLM"]
+        elif self.original_arch == "Qwen3VLMoeForConditionalGeneration":
+            llm_model_config.pretrained_config.architectures = ["Qwen3MoeForCausalLM"]
+        else:
+            raise ValueError(f"Unsupported architecture: {self.original_arch}")
         self.llm = AutoModelForCausalLM.from_config(llm_model_config)

         if not _is_disagg():
@@ -990,3 +1007,42 @@ def forward(
         )
         logger.debug(f"output shape: {output_prob.shape}")
         return output_prob
+
+
+@register_vision_encoder(Qwen3VisionModelBase, vlm_base_model=Qwen3VisionModel)
+@register_auto_model("Qwen3VLForConditionalGeneration")
+@register_input_processor(
+    Qwen3VLInputProcessorBase,
+    model_type="qwen3_vl",
+    placeholder_metadata=MultimodalPlaceholderMetadata(
+        placeholder_map={
+            "image": "<|vision_start|><|image_pad|><|vision_end|>",
+            "video": "<|vision_start|><|video_pad|><|vision_end|>",
+        },
+        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
+    ),
+)
+class Qwen3VLModel(Qwen3VLModelBase):
+    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args, **kwargs):
+        # NOTE: HF implementation.
+        kwargs["vision_model_class"] = Qwen3VisionModel
+        kwargs["disable_fuse_rope"] = kwargs.get(
+            "disable_fuse_rope", False
+        ) # TODO: Make this ModelConfig's argument
+        super().__init__(model_config, *args, **kwargs)

+    @property
+    def multimodal_data_device_paths(self) -> List[str]:
+        return ["image.pixel_values", "video.pixel_values_videos", "multimodal_embedding"]

+    def load_weights(self, weights: Dict[str, torch.Tensor], weight_mapper: BaseWeightMapper):
+        if not _is_disagg():
+            self.mm_encoder.load_weights(weights)

+        weight_mapper = Qwen3VLHfWeightMapper()
+        weight_mapper.init_model_and_config(self.llm, self.model_config)
+        filtered_weights = {k: v for k, v in weights.items() if not k.startswith("model.visual.")}
+        params_map = {
+            r"^model\.language_model\.(.*)$": r"model.\1",
+        }
+        self.llm.load_weights(filtered_weights, weight_mapper, params_map=params_map)
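As a standalone illustration of what this load_weights override does with params_map, the sketch below filters out the vision-tower weights (which the multimodal encoder loads separately) and rewrites HF checkpoint keys from the model.language_model. namespace into the LLM's model. namespace. The sample key names are made up for demonstration, and the actual remapping is applied inside the LLM's own load_weights path rather than by hand like this.

    import re

    # Made-up checkpoint keys for demonstration.
    weights = {
        "model.visual.blocks.0.attn.qkv.weight": 0,
        "model.language_model.layers.0.self_attn.q_proj.weight": 1,
        "lm_head.weight": 2,
    }

    params_map = {r"^model\.language_model\.(.*)$": r"model.\1"}

    # Drop vision-tower weights ...
    filtered = {k: v for k, v in weights.items() if not k.startswith("model.visual.")}
    # ... then remap the language-model prefix onto the plain "model." namespace.
    remapped = {}
    for key, value in filtered.items():
        for pattern, repl in params_map.items():
            key = re.sub(pattern, repl, key)
        remapped[key] = value

    print(sorted(remapped))
    # ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']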
