Commit 7295af6

[None][fix] Enable AttentionDP on Qwen3-VL and fix test (#10435)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Parent: 1c69aad

File tree
7 files changed (+90, -17 lines)

tensorrt_llm/_torch/models/modeling_qwen2vl.py
Lines changed: 5 additions & 2 deletions

@@ -502,8 +502,10 @@ def forward(self, multimodal_params: List[MultimodalParams]):
 
 class Qwen2_5_VLVisionAttention(Attention):
 
-    def __init__(self, model_config: ModelConfig[PretrainedConfig],
-                 layer_idx: int) -> None:
+    def __init__(self,
+                 model_config: ModelConfig[PretrainedConfig],
+                 layer_idx: int,
+                 reduce_output: bool = True) -> None:
 
         config = model_config.pretrained_config.vision_config
         super().__init__(
@@ -518,6 +520,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )
 
     def forward(
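
Context for the reduce_output flag threaded through above: with row-parallel tensor parallelism the attention output projection leaves each rank holding a partial sum, which the module normally all-reduces; under attention DP each rank runs the full layer on its own requests, so the reduce can be skipped. A minimal standalone sketch (plain Python, ranks simulated in-process; not TRT-LLM code):

# Toy model of a row-parallel projection: the contraction dimension is split
# across "ranks", so each rank produces only a partial dot product.
def row_parallel_partials(x, w, tp_size):
    chunk = len(x) // tp_size
    return [
        sum(x[r * chunk + i] * w[r * chunk + i] for i in range(chunk))
        for r in range(tp_size)
    ]

x = [1.0, 2.0, 3.0, 4.0]
w = [0.5, 0.5, 0.5, 0.5]
partials = row_parallel_partials(x, w, tp_size=2)
# reduce_output=True corresponds to summing the partials across ranks:
assert sum(partials) == sum(xi * wi for xi, wi in zip(x, w))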

tensorrt_llm/_torch/models/modeling_qwen3vl.py
Lines changed: 33 additions & 6 deletions

@@ -15,6 +15,7 @@
 
 from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.mapping import Mapping
 
 from ..._utils import nvtx_range, nvtx_range_debug
 from ...inputs import (
@@ -439,7 +440,13 @@ def __init__(self, model_config, layer_idx):
         model_config.pretrained_config.vision_config.torch_dtype = (
             model_config.pretrained_config.text_config.dtype
         )
-        super().__init__(model_config, layer_idx)
+        super().__init__(
+            model_config,
+            layer_idx=layer_idx,
+            reduce_output=(
+                not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size > 1
+            ),
+        )
 
 
 class Qwen3VLVisionMLP(MLP):
@@ -453,12 +460,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
             dtype=model_config.pretrained_config.text_config.dtype,
             config=model_config,
             layer_idx=layer_idx,
+            overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None,
         )
 
 
 class Qwen3VLVisionBlock(torch.nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config.vision_config
 
         self.norm1 = LayerNorm(
@@ -510,11 +519,29 @@ def __init__(
             eps=model_config.pretrained_config.text_config.rms_norm_eps,
             dtype=model_config.pretrained_config.text_config.dtype,
         )
+
+        self.mapping = model_config.mapping
+        overridden_tp_size = 1 if model_config.mapping.enable_attention_dp else None
+        if overridden_tp_size is not None:
+            assert self.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = self.mapping.pp_size * self.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = self.mapping
+
         self.linear_fc1 = Linear(
             in_features=self.hidden_size,
             out_features=self.hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -523,7 +550,7 @@ def __init__(
             in_features=self.hidden_size,
             out_features=config.out_hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -705,16 +732,16 @@ def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
 
     @torch.inference_mode()
     def forward(
-        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
-    ) -> torch.Tensor:
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
         attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)
 
         # Getting positional embedding
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # From this point, pure GPU operation
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(pixel_values)
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len, -1)
 
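
The expression passed to reduce_output in the vision attention above reads as a small truth table; a sketch of how it evaluates (values are illustrative):

def needs_reduce(enable_attention_dp: bool, tp_size: int) -> bool:
    # Mirrors `not mapping.enable_attention_dp and mapping.tp_size > 1` above.
    return not enable_attention_dp and tp_size > 1

assert needs_reduce(False, 4)       # plain TP: partial sums must be reduced
assert not needs_reduce(True, 4)    # attention DP: each rank is self-contained
assert not needs_reduce(False, 1)   # single rank: nothing to reduce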

tensorrt_llm/_torch/modules/mlp.py
Lines changed: 22 additions & 3 deletions

@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -20,7 +22,8 @@ def __init__(self,
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
                  layer_idx: Optional[int] = None,
-                 reduce_output: bool = True):
+                 reduce_output: bool = True,
+                 overridden_tp_size: Optional[int] = None):
 
         super().__init__()
         self.layer_idx = layer_idx
@@ -29,6 +32,22 @@ def __init__(self,
         self.activation = activation
 
         config = config or ModelConfig()
+        self.mapping = config.mapping
+        if overridden_tp_size is not None:
+            assert config.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = config.mapping.pp_size * config.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = config.mapping
+
         self.up_lora = LoraLayer(
             [LoraModuleType.MLP_H_TO_4H],
             [self.intermediate_size // config.mapping.tp_size])
@@ -38,7 +57,7 @@ def __init__(self,
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +74,7 @@ def __init__(self,
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
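
The overridden_tp_size path above builds a second Mapping whose world size is unchanged (tp_size * pp_size stays constant, since pp_size is inflated by the same factor tp_size shrinks), but whose TP groups are smaller; with overridden_tp_size=1 each rank becomes its own TP group and the row-parallel all-reduce degenerates to a no-op. A simplified sketch of the group arithmetic, assuming the common rank = pp_rank * tp_size + tp_rank layout (the real Mapping class carries more state):

def tp_group(rank, tp_size):
    # Ranks that would all-reduce together under this layout.
    start = (rank // tp_size) * tp_size
    return list(range(start, start + tp_size))

# Original mapping on 4 GPUs: tp_size=4, pp_size=1 -> one group of four.
print([tp_group(r, tp_size=4) for r in range(4)])  # [[0, 1, 2, 3]] * 4

# Overridden: tp_size=1, pp_size=4 ("misused" so world_size is still 4).
print([tp_group(r, tp_size=1) for r in range(4)])  # [[0], [1], [2], [3]]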

tests/integration/defs/accuracy/references/mmmu.yaml
Lines changed: 2 additions & 0 deletions

@@ -27,3 +27,5 @@ Qwen/Qwen3-VL-30B-A3B-Instruct:
 mistral/Mistral-Large-3-675B:
   # Mistral Large 3 675B only supports single image input, so accuracy is lower.
   - accuracy: 47
+Qwen/Qwen3-VL-8B-Instruct:
+  - accuracy: 55.11

tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
Lines changed: 18 additions & 0 deletions

@@ -327,3 +327,21 @@ def test_nvfp4_4gpus(
         ) as llm:
             task = MMMU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=self.sampling_params)
+
+
+class TestQwen3VL(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen/Qwen3-VL-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-VL-8B-Instruct"
+    MAX_NUM_TOKENS = 16384
+
+    sampling_params = SamplingParams(
+        max_tokens=MAX_NUM_TOKENS, truncate_prompt_tokens=MMMU.MAX_INPUT_LEN, stop="<|endoftext|>"
+    )
+
+    def test_auto_dtype(self):
+        with LLM(
+            self.MODEL_PATH,
+            max_num_tokens=self.MAX_NUM_TOKENS,
+        ) as llm:
+            task = MMMU(self.MODEL_NAME)
+            task.evaluate(llm, sampling_params=self.sampling_params)
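
For reference, the new accuracy test should be selectable with a standard pytest node ID from the repo root (invocation assumed, not part of this commit):

pytest tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py::TestQwen3VL::test_auto_dtype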

tests/integration/test_lists/test-db/l0_l40s.yml
Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,7 @@ l0_l40s:
       backend: pytorch
   tests:
   # ------------- PyTorch tests ---------------
+  # Multimodal modeling tests
  - unittest/_torch/modeling -k "modeling_mllama"
  - unittest/_torch/modeling -k "modeling_siglip"
  - unittest/_torch/modeling -k "modeling_vila"
@@ -22,6 +23,7 @@ l0_l40s:
  - unittest/_torch/modeling/test_modeling_llava_next.py::TestLlavaNext::test_all
  - unittest/_torch/modeling/test_modeling_qwen2_5vl.py::TestQwen2_5_VL::test_all
  - unittest/_torch/modeling/test_modeling_qwen3vl_moe.py::TestQwen3VLMoe::test_all
+ - unittest/_torch/modeling/test_modeling_qwen3vl.py::TestQwen3VL::test_all
  - test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
  - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-audio]
  - test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-image]

tests/unittest/_torch/modeling/test_modeling_qwen3vl.py
Lines changed: 8 additions & 6 deletions

@@ -237,12 +237,6 @@ def get_scenarios(self) -> List[TestQwen3VLScenario]:
                 chunked_prefill=False,
                 kv_cache_reuse=False,
             ),
-            # ==== Disable fuse rope scenarios ====
-            # TestQwen3VLScenario(modality="image",
-            #                     use_cuda_graph=False,
-            #                     disable_fuse_rope=True,
-            #                     chunked_prefill=False,
-            #                     kv_cache_reuse=False),
             # ==== Chunked Prefill Scenarios ====
             TestQwen3VLScenario(
                 modality="image",
@@ -259,6 +253,14 @@ def get_scenarios(self) -> List[TestQwen3VLScenario]:
                 chunked_prefill=False,
                 kv_cache_reuse=True,
             ),
+            # ==== Disable fuse rope scenarios ====
+            TestQwen3VLScenario(
+                modality="image",
+                use_cuda_graph=False,
+                disable_fuse_rope=True,
+                chunked_prefill=False,
+                kv_cache_reuse=False,
+            ),
         ]
         return scenarios
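
To make the scenario matrix above concrete, a hypothetical reconstruction of the record it instantiates; the real TestQwen3VLScenario is defined elsewhere in this test module, and only the field names visible in the keyword arguments above are assumed:

from dataclasses import dataclass

@dataclass(frozen=True)
class TestQwen3VLScenario:
    modality: str
    use_cuda_graph: bool
    disable_fuse_rope: bool
    chunked_prefill: bool
    kv_cache_reuse: bool

# The re-enabled unfused-RoPE case from the hunk above:
scenario = TestQwen3VLScenario(
    modality="image",
    use_cuda_graph=False,
    disable_fuse_rope=True,
    chunked_prefill=False,
    kv_cache_reuse=False,
)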
