diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index d421b31de50..4ed21f34bee 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -502,8 +502,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
 class Qwen2_5_VLVisionAttention(Attention):
 
-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 layer_idx: int) -> None:
+    def __init__(self,
+                 model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int,
+                 reduce_output: bool = True) -> None:
         config = model_config.pretrained_config.vision_config
 
         super().__init__(
@@ -518,6 +520,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )
 
     def forward(
diff --git a/tensorrt_llm/_torch/models/modeling_qwen3vl.py b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
index d073f6745b7..4bdcfc76784 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen3vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen3vl.py
@@ -15,6 +15,7 @@
 
 from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.mapping import Mapping
 
 from ..._utils import nvtx_range, nvtx_range_debug
 from ...inputs import (
@@ -439,7 +440,13 @@ def __init__(self, model_config, layer_idx):
         model_config.pretrained_config.vision_config.torch_dtype = (
             model_config.pretrained_config.text_config.dtype
         )
-        super().__init__(model_config, layer_idx)
+        super().__init__(
+            model_config,
+            layer_idx=layer_idx,
+            reduce_output=(
+                not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size > 1
+            ),
+        )
 
 
 class Qwen3VLVisionMLP(MLP):
@@ -453,12 +460,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
             dtype=model_config.pretrained_config.text_config.dtype,
             config=model_config,
             layer_idx=layer_idx,
+            overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None,
         )
 
 
 class Qwen3VLVisionBlock(torch.nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config.vision_config
 
         self.norm1 = LayerNorm(
@@ -510,11 +519,29 @@ def __init__(
             eps=model_config.pretrained_config.text_config.rms_norm_eps,
             dtype=model_config.pretrained_config.text_config.dtype,
         )
+
+        self.mapping = model_config.mapping
+        overridden_tp_size = 1 if model_config.mapping.enable_attention_dp else None
+        if overridden_tp_size is not None:
+            assert self.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = self.mapping.pp_size * self.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = self.mapping
+
         self.linear_fc1 = Linear(
             in_features=self.hidden_size,
             out_features=self.hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -523,7 +550,7 @@ def __init__(
             in_features=self.hidden_size,
             out_features=config.out_hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -705,8 +732,8 @@ def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
 
     @torch.inference_mode()
     def forward(
-        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
-    ) -> torch.Tensor:
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                            grid_thw[:, 0]).tolist()
         attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)
@@ -714,7 +741,7 @@ def forward(
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # From this point, pure GPU operation
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(pixel_values)
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len, -1)
diff --git a/tensorrt_llm/_torch/modules/mlp.py b/tensorrt_llm/_torch/modules/mlp.py
index 49fba9aeffc..d121457b487 100644
--- a/tensorrt_llm/_torch/modules/mlp.py
+++ b/tensorrt_llm/_torch/modules/mlp.py
@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -20,7 +22,8 @@ def __init__(self,
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
                  layer_idx: Optional[int] = None,
-                 reduce_output: bool = True):
+                 reduce_output: bool = True,
+                 overridden_tp_size: Optional[int] = None):
         super().__init__()
         self.layer_idx = layer_idx
@@ -29,6 +32,22 @@ def __init__(self,
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
+        if overridden_tp_size is not None:
+            assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = config.mapping
+
         self.up_lora = LoraLayer(
             [LoraModuleType.MLP_H_TO_4H],
             [self.intermediate_size // config.mapping.tp_size])
@@ -38,7 +57,7 @@ def __init__(self,
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +74,7 @@ def __init__(self,
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
diff --git a/tests/integration/defs/accuracy/references/mmmu.yaml b/tests/integration/defs/accuracy/references/mmmu.yaml
index 37819c3f14b..c612e0182cb 100644
--- a/tests/integration/defs/accuracy/references/mmmu.yaml
+++ b/tests/integration/defs/accuracy/references/mmmu.yaml
@@ -27,3 +27,5 @@ Qwen/Qwen3-VL-30B-A3B-Instruct:
 mistral/Mistral-Large-3-675B:
   # Mistral Large 3 675B only supports single image input, so accuracy is lower.
   - accuracy: 47
+Qwen/Qwen3-VL-8B-Instruct:
+  - accuracy: 55.11
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py b/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
index c3a812b195b..6c657f95212 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
@@ -327,3 +327,21 @@ def test_nvfp4_4gpus(
         ) as llm:
             task = MMMU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=self.sampling_params)
+
+
+class TestQwen3VL(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen/Qwen3-VL-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-VL-8B-Instruct"
+    MAX_NUM_TOKENS = 16384
+
+    sampling_params = SamplingParams(
+        max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
+    )
+
+    def test_auto_dtype(self):
+        with LLM(
+            self.MODEL_PATH,
+            max_num_tokens=self.MAX_NUM_TOKENS,
+        ) as llm:
+            task = MMMU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=self.sampling_params)
diff --git a/tests/integration/test_lists/test-db/l0_l40s.yml b/tests/integration/test_lists/test-db/l0_l40s.yml
index 68a46e81127..c3037894895 100644
--- a/tests/integration/test_lists/test-db/l0_l40s.yml
+++ b/tests/integration/test_lists/test-db/l0_l40s.yml
@@ -14,6 +14,7 @@ l0_l40s:
       backend: pytorch
   tests:
   # ------------- PyTorch tests ---------------
+  # Multimodal modeling tests
   - unittest/_torch/modeling -k "modeling_mllama"
   - unittest/_torch/modeling -k "modeling_siglip"
   - unittest/_torch/modeling -k "modeling_vila"
@@ -22,6 +23,7 @@ l0_l40s:
   - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all
   - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
   - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
+  - unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all
   - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
   - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]
diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen3vl.py b/tests/unittest/_torch/modeling/test_modeling_qwen3vl.py
index 52ba53a5b06..35fb5983606 100644
--- a/tests/unittest/_torch/modeling/test_modeling_qwen3vl.py
+++ b/tests/unittest/_torch/modeling/test_modeling_qwen3vl.py
@@ -237,12 +237,6 @@ def get_scenarios(self) -> List[TestQwen3VLScenario]:
                 chunked_prefill=False,
                 kv_cache_reuse=False,
             ),
-            # ==== Disable fuse rope scenarios ====
-            # TestQwen3VLScenario(modality="image",
-            #                     use_cuda_graph=False,
-            #                     disable_fuse_rope=True,
-            #                     chunked_prefill=False,
-            #                     kv_cache_reuse=False),
             # ==== Chunked Prefill Scenarios ====
             TestQwen3VLScenario(
                 modality="image",
@@ -259,6 +253,14 @@ def get_scenarios(self) -> List[TestQwen3VLScenario]:
                 chunked_prefill=False,
                 kv_cache_reuse=True,
             ),
+            # ==== Disable fuse rope scenarios ====
+            TestQwen3VLScenario(
+                modality="image",
+                use_cuda_graph=False,
+                disable_fuse_rope=True,
+                chunked_prefill=False,
+                kv_cache_reuse=False,
+            ),
         ]
         return scenarios
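
Note on the `overridden_tp_size` path above: it relies on `Mapping` deriving the tensor-parallel group from the global rank, so shrinking `tp_size` while folding the leftover factor into `pp_size` keeps `world_size` constant but splits ranks into smaller all-reduce groups; with `overridden_tp_size=1` each group is a single rank, i.e. no cross-rank reduction under attention DP, which also matches `reduce_output=False` on the vision attention. The sketch below is a toy illustration of that regrouping only: `ToyMapping` and `override_tp` are hypothetical stand-ins, not TensorRT-LLM APIs, and they assume TP is the innermost (fastest-varying) rank dimension as in `tensorrt_llm.mapping.Mapping`.

# Toy sketch (runnable, standalone). Shows how keeping the same global rank
# while shrinking tp_size and inflating pp_size regroups ranks into smaller
# tensor-parallel (all-reduce) groups. ToyMapping/override_tp are hypothetical.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyMapping:
    world_size: int
    rank: int
    tp_size: int
    pp_size: int

    @property
    def tp_group(self) -> List[int]:
        # Ranks that share this rank's tensor-parallel all-reduce group,
        # assuming TP is the innermost rank dimension.
        start = self.rank - self.rank % self.tp_size
        return list(range(start, start + self.tp_size))


def override_tp(m: ToyMapping, overridden_tp_size: int) -> ToyMapping:
    assert m.tp_size % overridden_tp_size == 0
    # Same trick as in the diff: "misuse" pp_size so that
    # world_size (= tp_size * pp_size) still covers every rank.
    pp_size = m.pp_size * m.tp_size // overridden_tp_size
    return ToyMapping(world_size=overridden_tp_size * pp_size,
                      rank=m.rank,
                      tp_size=overridden_tp_size,
                      pp_size=pp_size)


full = ToyMapping(world_size=4, rank=2, tp_size=4, pp_size=1)
print(full.tp_group)                  # [0, 1, 2, 3]: all-reduce over 4 ranks
print(override_tp(full, 2).tp_group)  # [2, 3]: reduce within a pair of ranks
print(override_tp(full, 1).tp_group)  # [2]: fully local, no communication

With attention DP enabled, each rank processes its own requests' vision tokens end to end, so the vision MLP and patch merger keep full unsharded weights (`overridden_tp_size=1`) and skip the TP all-reduce entirely rather than paying for communication that would only recombine locally complete activations.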