diff --git a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
index 4ea7c4a5237..dfa84ceaeea 100644
--- a/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
+++ b/tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
@@ -11,7 +11,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):

     def preprocess_weights(self, weights: dict) -> dict:
         config = self.config.pretrained_config
-        tp_size = self.config.mapping.tp_size
+        tp_size = 1 if self.config.mapping.enable_attention_dp else self.config.mapping.tp_size
         tp_rank = self.config.mapping.tp_rank
         d_inner = config.mamba_head_dim * config.mamba_num_heads

diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
index baf79012cc6..6df300cc692 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
@@ -32,7 +32,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
         layer_idx: int,
+        reduce_output: bool = False,
     ):
         config = model_config.pretrained_config
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )

     def forward(
@@ -154,6 +157,7 @@ def __init__(
         shared_expert_intermediate_size = (
             config.moe_shared_expert_intermediate_size *
             config.n_shared_experts)
+
         self.shared_experts = MLP(
             hidden_size=config.hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -163,7 +167,8 @@
             config=model_config,
             layer_idx=self.layer_idx,
             reduce_output=False,
-        )
+            overridden_tp_size=1
+            if model_config.mapping.enable_attention_dp else None)
         # Setup MoE gate.
         self.gate = DeepseekV3Gate(
             self.hidden_size,
@@ -193,11 +198,14 @@
             activation_type=self.activation_type,
         )

-        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
-        self.allreduce = AllReduce(
-            mapping=model_config.mapping,
-            strategy=model_config.allreduce_strategy,
-        )
+        if not model_config.mapping.enable_attention_dp:
+            # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+            self.allreduce = AllReduce(
+                mapping=model_config.mapping,
+                strategy=model_config.allreduce_strategy,
+            )
+        else:
+            self.allreduce = None

         # Setup latent projection layers.
         # These layers should NOT be TP-sharded to ensure MoE receives
@@ -322,7 +330,11 @@ def __init__(
         elif layer_type == "-":
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
-            self.mixer = TransformerLayer(model_config, layer_idx)
+            self.mixer = TransformerLayer(
+                model_config,
+                layer_idx,
+                reduce_output=not model_config.mapping.enable_attention_dp
+                and model_config.mapping.tp_size > 1)
         elif layer_type == "E":
             self.mixer = NemotronHMOE(model_config,
                                       layer_idx=layer_idx,
@@ -365,12 +377,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             aux_stream_list[2],
         }

-        # calculate embeddings
-        self.embed_tokens = Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=config.torch_dtype,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do an all_reduce because
+            # the problem sizes on different ranks are different, so we do not
+            # apply tensor parallelism here.
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+                mapping=model_config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )

         # create layers
         layers = []
diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
index 5b8d133f296..44ab0e51a4b 100644
--- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
+++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -20,6 +20,7 @@
 from torch import nn

 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm.mapping import Mapping

 from ...attention_backend import AttentionMetadata
 from ...model_config import ModelConfig
@@ -57,8 +58,20 @@ def __init__(

         config = config or ModelConfig()
         self.mapping = config.mapping
-        tp_rank = config.mapping.tp_rank
-        tp_size = config.mapping.tp_size
+
+        if config.mapping.enable_attention_dp:
+            self.mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+            tp_size = 1
+        else:
+            self.mapping = config.mapping
+            tp_size = config.mapping.tp_size

         d_inner = head_dim * nheads
         d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
@@ -80,10 +93,6 @@
         self.remove_padding = remove_padding
         self.apply_silu = apply_silu

-        # tp
-        self.tp_size = tp_size
-        self.tp_rank = tp_rank
-
         # paged state parameters
         self.slot_mapping = None
         self.is_paged_state = False
diff --git a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
index dfc497f18e9..55b20937f9e 100644
--- a/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -45,7 +45,7 @@ def __init__(
         self.mamba_ssm_cache_dtype = ssm_cache_dtype

         # get tp size
-        tp_size = mapping.tp_size
+        tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1

         # derive mamba parameters for conv and ssm states
         d_inner = head_dim * num_heads
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 01f586accf2..0ca0842b006 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -5276,6 +5276,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
             (4, 4, False, True, True),
             (4, 1, False, False, True),
             (4, 4, True, False, True),
+            (4, 4, True, True, True),
             (4, 1, True, True, True),
             (4, 4, False, True, False),
             (4, 1, False, False, False),
@@ -5285,9 +5286,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
     )
     def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
                               overlap_scheduler, cuda_graph):
-        if attention_dp:
-            pytest.skip(
-                "Attention DP is not supported for Nemotron-3-Super yet")

         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         mamba_ssm_cache_dtype="float32")
@@ -5314,7 +5312,18 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,

     @skip_pre_blackwell
     @pytest.mark.skip_less_mpi_world_size(8)
-    def test_nvfp4_8gpus(self):
+    @pytest.mark.parametrize(
+        "attention_dp",
+        [
+            False,
+            True,
+        ],
+        ids=[
+            "attention_dp_off",
+            "attention_dp_on",
+        ],
+    )
+    def test_nvfp4_8gpus(self, attention_dp):
         # Use this test to track the best performance config.
         # The optimized config is still under investigation.
         # Adding this test as placeholder.
@@ -5329,7 +5338,7 @@ def test_nvfp4_8gpus(self):
             tensor_parallel_size=8,
             moe_expert_parallel_size=8,
             pipeline_parallel_size=1,
-            enable_attention_dp=False,
+            enable_attention_dp=attention_dp,
             cuda_graph_config=CudaGraphConfig(max_batch_size=32,
                                               enable_padding=True),
             disable_overlap_scheduler=False,
diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt
index 09dc8b22677..8e7873f61c1 100644
--- a/tests/integration/test_lists/qa/llm_function_core.txt
+++ b/tests/integration/test_lists/qa/llm_function_core.txt
@@ -247,7 +247,9 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
-accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_off]

 # multimodal accuracy tests
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
index eb7ab1bfc60..b7a31e57d95 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -23,7 +23,7 @@ l0_dgx_b200:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -85,7 +85,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on] TIMEOUT (60)
 - condition:
     ranges:
       system_gpu_count:
@@ -163,6 +163,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
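Note: the snippet below is a minimal, illustrative sketch of how the newly enabled attention-DP configuration is exercised through the LLM API, mirroring the arguments used in test_nvfp4_8gpus above. The checkpoint path and prompt are placeholders, the llmapi import locations are assumed to match the current tensorrt_llm release, and the accuracy-harness plumbing used by the test is omitted.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

# Keep the Mamba SSM cache in float32 and disable block reuse, as the tests above do.
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                mamba_ssm_cache_dtype="float32")

llm = LLM(
    "<path-to-nemotron-v3-super-checkpoint>",  # placeholder path
    tensor_parallel_size=8,
    moe_expert_parallel_size=8,
    pipeline_parallel_size=1,
    enable_attention_dp=True,  # the configuration enabled by this change
    cuda_graph_config=CudaGraphConfig(max_batch_size=32, enable_padding=True),
    disable_overlap_scheduler=False,
    kv_cache_config=kv_cache_config,
)

outputs = llm.generate(["The capital of France is"])  # placeholder prompt
print(outputs[0].outputs[0].text)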