
Commit 1077cea

[TRTLLM-10060][feat] Enable attention dp for Nemotron Super v3.
Signed-off-by: nv-guomingz <[email protected]>
1 parent e033129 commit 1077cea

File tree

8 files changed (+90 / -35 lines)

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):

     def preprocess_weights(self, weights: dict) -> dict:
         config = self.config.pretrained_config
-        tp_size = self.config.mapping.tp_size
+        tp_size = 1 if self.config.mapping.enable_attention_dp else self.config.mapping.tp_size
         tp_rank = self.config.mapping.tp_rank
         d_inner = config.mamba_head_dim * config.mamba_num_heads
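For context, a minimal sketch of what the guard changes (shard_rows is a hypothetical helper, not part of the commit): with attention DP the mapper acts as if tp_size were 1, so each rank keeps the full Mamba weight tensor instead of a 1/tp_size slice.

import torch

def shard_rows(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Keep the 1/tp_size block of rows owned by this rank.
    rows = weight.shape[0] // tp_size
    return weight[tp_rank * rows:(tp_rank + 1) * rows]

w = torch.randn(8, 4)
print(shard_rows(w, tp_size=4, tp_rank=0).shape)  # torch.Size([2, 4]) -- plain TP slice
print(shard_rows(w, tp_size=1, tp_rank=0).shape)  # torch.Size([8, 4]) -- attention DP keeps it whole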

tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 36 additions & 13 deletions
@@ -32,7 +32,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
         layer_idx: int,
+        reduce_output: bool = False,
     ):
         config = model_config.pretrained_config
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@ def __init__(
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )

     def forward(
@@ -154,6 +157,7 @@ def __init__(
         shared_expert_intermediate_size = (
             config.moe_shared_expert_intermediate_size *
             config.n_shared_experts)
+
         self.shared_experts = MLP(
             hidden_size=config.hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -193,11 +197,14 @@ def __init__(
             activation_type=self.activation_type,
         )

-        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
-        self.allreduce = AllReduce(
-            mapping=model_config.mapping,
-            strategy=model_config.allreduce_strategy,
-        )
+        if not model_config.mapping.enable_attention_dp:
+            # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+            self.allreduce = AllReduce(
+                mapping=model_config.mapping,
+                strategy=model_config.allreduce_strategy,
+            )
+        else:
+            self.allreduce = None

         # Setup latent projection layers.
         # These layers should NOT be TP-sharded to ensure MoE receives
@@ -322,7 +329,11 @@ def __init__(
         elif layer_type == "-":
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
-            self.mixer = TransformerLayer(model_config, layer_idx)
+            self.mixer = TransformerLayer(
+                model_config,
+                layer_idx,
+                reduce_output=not model_config.mapping.enable_attention_dp
+                and model_config.mapping.tp_size > 1)
         elif layer_type == "E":
             self.mixer = NemotronHMOE(model_config,
                                       layer_idx=layer_idx,
@@ -365,12 +376,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             aux_stream_list[2],
         }

-        # calculate embeddings
-        self.embed_tokens = Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=config.torch_dtype,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem size of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+                mapping=model_config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )

         # create layers
         layers = []
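To spell out the two parallelism decisions this file now makes, a small illustrative sketch (helper names are hypothetical, not from the commit):

def attention_reduce_output(enable_attention_dp: bool, tp_size: int) -> bool:
    # Mirrors the reduce_output argument passed to TransformerLayer above:
    # only all-reduce the attention output projection under classic TP.
    return not enable_attention_dp and tp_size > 1

def moe_builds_allreduce(enable_attention_dp: bool) -> bool:
    # Mirrors the NemotronHMOE branch: with attention DP the per-rank batch
    # sizes differ, so no collective is created and self.allreduce stays None.
    return not enable_attention_dp

print(attention_reduce_output(enable_attention_dp=False, tp_size=4))  # True
print(attention_reduce_output(enable_attention_dp=True, tp_size=4))   # False
print(moe_builds_allreduce(enable_attention_dp=True))                 # False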

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 15 additions & 6 deletions
@@ -20,6 +20,7 @@
 from torch import nn

 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm.mapping import Mapping

 from ...attention_backend import AttentionMetadata
 from ...model_config import ModelConfig
@@ -57,8 +58,20 @@ def __init__(

         config = config or ModelConfig()
         self.mapping = config.mapping
-        tp_rank = config.mapping.tp_rank
-        tp_size = config.mapping.tp_size
+
+        if config.mapping.enable_attention_dp:
+            self.mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+            tp_size = 1
+        else:
+            self.mapping = config.mapping
+            tp_size = config.mapping.tp_size

         d_inner = head_dim * nheads
         d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
@@ -80,10 +93,6 @@
         self.remove_padding = remove_padding
         self.apply_silu = apply_silu

-        # tp
-        self.tp_size = tp_size
-        self.tp_rank = tp_rank
-
         # paged state parameters
         self.slot_mapping = None
         self.is_paged_state = False
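The hunk above is the core pattern of this change: when attention DP is enabled, the mixer builds a private TP=1 Mapping so its projections stay unsharded. A standalone sketch, assuming TensorRT-LLM is installed and using illustrative values (rank 0 of a 4-GPU TP group):

from tensorrt_llm.mapping import Mapping

parent = Mapping(world_size=4, tp_size=4, pp_size=1, rank=0,
                 gpus_per_node=4, enable_attention_dp=True)

if parent.enable_attention_dp:
    # Collapse TP; only pipeline parallelism (if any) remains for this module.
    local = Mapping(
        world_size=parent.pp_size,
        tp_size=1,
        pp_size=parent.pp_size,
        rank=parent.rank,
        gpus_per_node=parent.gpus_per_node,
        enable_attention_dp=True,
    )
else:
    local = parent

print(local.tp_size)  # 1 -> the Mamba in/out projections are not TP-sharded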

tensorrt_llm/_torch/modules/mlp.py

Lines changed: 17 additions & 6 deletions
@@ -4,6 +4,8 @@
 import torch
 from torch import nn

+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -21,24 +23,33 @@ def __init__(self,
                 config: Optional[ModelConfig] = None,
                 layer_idx: Optional[int] = None,
                 reduce_output: bool = True):
-
+        if config.mapping.enable_attention_dp:
+            mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+        else:
+            mapping = config.mapping
         super().__init__()
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.activation = activation

         config = config or ModelConfig()
-        self.up_lora = LoraLayer(
-            [LoraModuleType.MLP_H_TO_4H],
-            [self.intermediate_size // config.mapping.tp_size])
+        self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
+                                 [self.intermediate_size // mapping.tp_size])

         self.up_proj = Linear(
             self.hidden_size,
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +66,7 @@ def __init__(self,
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
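The practical effect on the MLP is that both Linear layers receive the TP=1 mapping, so the COLUMN split of up_proj and the ROW split (and its all-reduce) of down_proj become no-ops. A back-of-the-envelope sketch with illustrative sizes (not the Nemotron Super v3 values):

intermediate_size = 14336  # illustrative
tp_size = 4

per_rank_plain_tp = intermediate_size // tp_size  # 3584 columns of up_proj per rank
per_rank_attention_dp = intermediate_size // 1    # 14336, the whole projection per rank
print(per_rank_plain_tp, per_rank_attention_dp)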

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def __init__(
         self.mamba_ssm_cache_dtype = ssm_cache_dtype

         # get tp size
-        tp_size = mapping.tp_size
+        tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1

         # derive mamba parameters for conv and ssm states
         d_inner = head_dim * num_heads
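The Mamba cache follows the same guard, so a quick sketch of the allocation arithmetic (head_dim and num_heads are illustrative values, not the actual model config):

head_dim, num_heads = 64, 128
d_inner = head_dim * num_heads                      # 8192, as in the cache manager

for enable_attention_dp in (False, True):
    tp_size = 4 if not enable_attention_dp else 1   # mirrors the guard above
    print(enable_attention_dp, d_inner // tp_size)  # 2048 per rank vs. the full 8192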

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 14 additions & 5 deletions
@@ -5273,6 +5273,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
         (4, 4, False, True, True),
         (4, 1, False, False, True),
         (4, 4, True, False, True),
+        (4, 4, True, True, True),
         (4, 1, True, True, True),
         (4, 4, False, True, False),
         (4, 1, False, False, False),
@@ -5282,9 +5283,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
     )
     def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
                               overlap_scheduler, cuda_graph):
-        if attention_dp:
-            pytest.skip(
-                "Attention DP is not supported for Nemotron-3-Super yet")

         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         mamba_ssm_cache_dtype="float32")
@@ -5311,7 +5309,18 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,

     @skip_pre_blackwell
     @pytest.mark.skip_less_mpi_world_size(8)
-    def test_nvfp4_8gpus(self):
+    @pytest.mark.parametrize(
+        "attention_dp",
+        [
+            False,
+            True,
+        ],
+        ids=[
+            "attention_dp_off",
+            "attention_dp_on",
+        ],
+    )
+    def test_nvfp4_8gpus(self, attention_dp):
         # Use this test to track the best performance config.
         # The optimized config is still under investigation.
         # Adding this test as placeholder.
@@ -5326,7 +5335,7 @@ def test_nvfp4_8gpus(self):
             tensor_parallel_size=8,
             moe_expert_parallel_size=8,
             pipeline_parallel_size=1,
-            enable_attention_dp=False,
+            enable_attention_dp=attention_dp,
             cuda_graph_config=CudaGraphConfig(max_batch_size=32,
                                               enable_padding=True),
             disable_overlap_scheduler=False,
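For reference, a hedged sketch of how the new path would be exercised through the LLM API, modeled on the arguments visible in the test above; the model path is a placeholder and the exact keyword set should be verified against the installed TensorRT-LLM version:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

llm = LLM(
    model="<path-to-Nemotron-Super-v3-checkpoint>",  # placeholder
    tensor_parallel_size=8,
    moe_expert_parallel_size=8,
    pipeline_parallel_size=1,
    enable_attention_dp=True,  # the switch this commit enables for Nemotron Super v3
    cuda_graph_config=CudaGraphConfig(max_batch_size=32, enable_padding=True),
    disable_overlap_scheduler=False,
    kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                  mamba_ssm_cache_dtype="float32"),
)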

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 3 additions & 1 deletion
@@ -691,7 +691,9 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
-accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_off]

 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 3 additions & 2 deletions
@@ -23,7 +23,7 @@ l0_dgx_b200:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -85,7 +85,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on] TIMEOUT (60)
 - condition:
     ranges:
       system_gpu_count:
@@ -163,6 +163,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
