@@ -15,6 +15,7 @@
 
 from tensorrt_llm._torch.models.modeling_multimodal_utils import _is_disagg
 from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.mapping import Mapping
 
 from ..._utils import nvtx_range, nvtx_range_debug
 from ...inputs import (
@@ -439,7 +440,13 @@ def __init__(self, model_config, layer_idx):
         model_config.pretrained_config.vision_config.torch_dtype = (
             model_config.pretrained_config.text_config.dtype
         )
-        super().__init__(model_config, layer_idx)
+        super().__init__(
+            model_config,
+            layer_idx=layer_idx,
+            reduce_output=(
+                not model_config.mapping.enable_attention_dp and model_config.mapping.tp_size > 1
+            ),
+        )
 
 
 class Qwen3VLVisionMLP(MLP):
@@ -453,12 +460,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
             dtype=model_config.pretrained_config.text_config.dtype,
             config=model_config,
             layer_idx=layer_idx,
+            overridden_tp_size=1 if model_config.mapping.enable_attention_dp else None,
         )
 
 
 class Qwen3VLVisionBlock(torch.nn.Module):
     def __init__(self, model_config: ModelConfig[PretrainedConfig], layer_idx: int):
         super().__init__()
+        self.model_config = model_config
         config = model_config.pretrained_config.vision_config
 
         self.norm1 = LayerNorm(
@@ -510,11 +519,29 @@ def __init__(
             eps=model_config.pretrained_config.text_config.rms_norm_eps,
             dtype=model_config.pretrained_config.text_config.dtype,
         )
+
+        self.mapping = model_config.mapping
+        overridden_tp_size = 1 if model_config.mapping.enable_attention_dp else None
+        if overridden_tp_size is not None:
+            assert self.mapping.tp_size % overridden_tp_size == 0
+            tp_size = overridden_tp_size
+            # "Misuse" pp_size here to perform all-reduce within smaller groups
+            pp_size = self.mapping.pp_size * self.mapping.tp_size // overridden_tp_size
+            mapping = Mapping(
+                world_size=tp_size * pp_size,
+                rank=self.mapping.rank,
+                gpus_per_node=self.mapping.gpus_per_node,
+                tp_size=tp_size,
+                pp_size=pp_size,
+            )
+        else:
+            mapping = self.mapping
+
         self.linear_fc1 = Linear(
             in_features=self.hidden_size,
             out_features=self.hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -523,7 +550,7 @@ def __init__(
             in_features=self.hidden_size,
             out_features=config.out_hidden_size,
             bias=True,
-            mapping=model_config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             allreduce_strategy=model_config.allreduce_strategy,
         )
@@ -705,16 +732,16 @@ def prepare_attn_metadata(self, seq_lens, attn_metadata: AttentionMetadata):
 
     @torch.inference_mode()
     def forward(
-        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
-    ) -> torch.Tensor:
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).tolist()
         attn_metadata = self.prepare_attn_metadata(seq_lens, self.attn_metadata)
 
         # Getting positional embedding
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
         # From this point, pure GPU operation
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(pixel_values)
         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len, -1)
 
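Note on the new patch-merger mapping override: when attention DP is enabled, the merger's `Linear` layers are handed a mapping with `tp_size=1`, and the ranks that would otherwise form the TP group are folded into `pp_size` so that, per the diff's own comment, any all-reduce stays within the smaller group. The sketch below only illustrates that regrouping arithmetic; `SimpleMapping` and `merger_mapping` are hypothetical stand-ins for this note, not part of the PR or of `tensorrt_llm.mapping`.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in for the Mapping fields the override actually reads.
@dataclass
class SimpleMapping:
    world_size: int
    rank: int
    gpus_per_node: int
    tp_size: int
    pp_size: int
    enable_attention_dp: bool = False


def merger_mapping(parent: SimpleMapping, overridden_tp_size: Optional[int]) -> SimpleMapping:
    """Shrink TP to `overridden_tp_size` and fold the leftover ranks into
    pp_size, mirroring the regrouping added in this diff (illustrative only)."""
    if overridden_tp_size is None:
        return parent
    assert parent.tp_size % overridden_tp_size == 0
    tp_size = overridden_tp_size
    # "Misuse" pp_size so collectives only span the smaller group.
    pp_size = parent.pp_size * parent.tp_size // overridden_tp_size
    return SimpleMapping(
        world_size=tp_size * pp_size,
        rank=parent.rank,
        gpus_per_node=parent.gpus_per_node,
        tp_size=tp_size,
        pp_size=pp_size,
        enable_attention_dp=parent.enable_attention_dp,
    )


if __name__ == "__main__":
    # Example: 4 TP ranks with attention DP on -> merger runs with tp_size=1,
    # so its row-parallel Linear issues no cross-rank all-reduce.
    parent = SimpleMapping(world_size=4, rank=2, gpus_per_node=8,
                           tp_size=4, pp_size=1, enable_attention_dp=True)
    child = merger_mapping(parent, 1 if parent.enable_attention_dp else None)
    print(child.tp_size, child.pp_size)  # -> 1 4
```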