Commit 4792cbf

[TRTLLM-10060][feat] Enable attention dp for Nemotron Super v3.

Signed-off-by: nv-guomingz <[email protected]>
1 parent: 2de22f1

7 files changed, +76 / -30 lines

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):
     def preprocess_weights(self, weights: dict) -> dict:
         config = self.config.pretrained_config
-        tp_size = self.config.mapping.tp_size
+        tp_size = 1 if self.config.mapping.enable_attention_dp else self.config.mapping.tp_size
         tp_rank = self.config.mapping.tp_rank
         d_inner = config.mamba_head_dim * config.mamba_num_heads
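
With attention DP enabled, the mapper effectively treats weight preprocessing as TP=1, so per-rank weight splitting becomes a no-op and every rank keeps the full tensors. A minimal, hypothetical sketch of the idea (not the repository's actual helper):

import torch

def shard_for_rank(weight: torch.Tensor, tp_size: int, tp_rank: int,
                   dim: int = 0) -> torch.Tensor:
    """Return this rank's shard of `weight`; a no-op when tp_size == 1."""
    if tp_size == 1:
        return weight
    return torch.chunk(weight, tp_size, dim=dim)[tp_rank].contiguous()

full = torch.randn(1024, 256)
assert shard_for_rank(full, tp_size=1, tp_rank=0).shape == (1024, 256)  # attention-DP path
assert shard_for_rank(full, tp_size=4, tp_rank=2).shape == (256, 256)   # regular TP path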

tensorrt_llm/_torch/models/modeling_nemotron_h.py
Lines changed: 38 additions & 14 deletions

@@ -32,7 +32,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
         layer_idx: int,
+        reduce_output: bool = False,
     ):
         config = model_config.pretrained_config
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@ def __init__(
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )

     def forward(
@@ -154,6 +157,7 @@ def __init__(
         shared_expert_intermediate_size = (
             config.moe_shared_expert_intermediate_size *
             config.n_shared_experts)
+
         self.shared_experts = MLP(
             hidden_size=config.hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -163,7 +167,8 @@ def __init__(
             config=model_config,
             layer_idx=self.layer_idx,
             reduce_output=False,
-        )
+            overridden_tp_size=1
+            if model_config.mapping.enable_attention_dp else None)
         # Setup MoE gate.
         self.gate = DeepseekV3Gate(
             self.hidden_size,
@@ -193,11 +198,14 @@ def __init__(
             activation_type=self.activation_type,
         )

-        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
-        self.allreduce = AllReduce(
-            mapping=model_config.mapping,
-            strategy=model_config.allreduce_strategy,
-        )
+        if not model_config.mapping.enable_attention_dp:
+            # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+            self.allreduce = AllReduce(
+                mapping=model_config.mapping,
+                strategy=model_config.allreduce_strategy,
+            )
+        else:
+            self.allreduce = None

         # Setup latent projection layers.
         # These layers should NOT be TP-sharded to ensure MoE receives
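
Under attention DP each rank runs the MoE on its own local batch, so there is no partial sum to combine across TP ranks and `self.allreduce` is left as `None`. A hedged, standalone sketch (illustrative names, not the module's actual forward) of how such a guarded collective is typically consumed:

import torch

def combine_expert_outputs(shared_out: torch.Tensor,
                           routed_out: torch.Tensor,
                           allreduce=None) -> torch.Tensor:
    # Local combine of shared-expert and routed-expert outputs.
    out = shared_out + routed_out
    if allreduce is not None:
        # TP path: each rank holds a partial result, reduce across the TP group.
        out = allreduce(out)
    return out

# Attention-DP path: allreduce is None, the local sum is already complete.
print(combine_expert_outputs(torch.ones(2, 4), torch.ones(2, 4)).sum())  # tensor(16.)
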
@@ -322,7 +330,11 @@ def __init__(
         elif layer_type == "-":
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
-            self.mixer = TransformerLayer(model_config, layer_idx)
+            self.mixer = TransformerLayer(
+                model_config,
+                layer_idx,
+                reduce_output=not model_config.mapping.enable_attention_dp
+                and model_config.mapping.tp_size > 1)
         elif layer_type == "E":
             self.mixer = NemotronHMOE(model_config,
                                       layer_idx=layer_idx,
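
`reduce_output` only needs to be true when real tensor parallelism is splitting the attention output projection; with attention DP (or TP of 1) each rank already holds a complete result. A tiny sketch of that predicate, mirroring the expression in the diff above:

def needs_output_reduce(tp_size: int, enable_attention_dp: bool) -> bool:
    # Mirrors `reduce_output=not enable_attention_dp and tp_size > 1` above.
    return not enable_attention_dp and tp_size > 1

assert needs_output_reduce(4, False) is True    # plain TP: allreduce the o_proj output
assert needs_output_reduce(4, True) is False    # attention DP: no cross-rank reduce
assert needs_output_reduce(1, False) is False   # single rank: nothing to reduce
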
@@ -365,12 +377,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             aux_stream_list[2],
         }

-        # calculate embeddings
-        self.embed_tokens = Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=config.torch_dtype,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem size of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+                mapping=model_config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )

         # create layers
         layers = []
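
The embedding now takes two paths: with attention DP the table is replicated on every rank (no sharding, no gather), otherwise it is column-parallel with `gather_output=True`. A rough plain-PyTorch analogue of that decision (not TRT-LLM's `Embedding` class):

import torch
from torch import nn

def build_embedding(vocab_size: int, hidden_size: int, tp_size: int,
                    enable_attention_dp: bool) -> nn.Embedding:
    if enable_attention_dp or tp_size == 1:
        # Attention DP: full, replicated table per rank; no collective needed.
        return nn.Embedding(vocab_size, hidden_size)
    # TP sketch: each rank owns hidden_size // tp_size columns and gathers outputs.
    return nn.Embedding(vocab_size, hidden_size // tp_size)

emb = build_embedding(vocab_size=32000, hidden_size=4096, tp_size=4,
                      enable_attention_dp=True)
print(emb(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4096])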

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
Lines changed: 15 additions & 6 deletions

@@ -20,6 +20,7 @@
 from torch import nn

 from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+from tensorrt_llm.mapping import Mapping

 from ...attention_backend import AttentionMetadata
 from ...model_config import ModelConfig
@@ -57,8 +58,20 @@ def __init__(

         config = config or ModelConfig()
         self.mapping = config.mapping
-        tp_rank = config.mapping.tp_rank
-        tp_size = config.mapping.tp_size
+
+        if config.mapping.enable_attention_dp:
+            self.mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+            tp_size = 1
+        else:
+            self.mapping = config.mapping
+            tp_size = config.mapping.tp_size

         d_inner = head_dim * nheads
         d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
@@ -80,10 +93,6 @@ def __init__(
         self.remove_padding = remove_padding
         self.apply_silu = apply_silu

-        # tp
-        self.tp_size = tp_size
-        self.tp_rank = tp_rank
-
         # paged state parameters
         self.slot_mapping = None
         self.is_paged_state = False
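
When attention DP is on, the mixer swaps in a degenerate `Mapping` with `tp_size=1` so the Mamba in/out projections stay unsharded on every rank. A small helper sketched from the diff above (assumes `tensorrt_llm` is importable; `mixer_mapping` is an illustrative name, not part of the module):

from tensorrt_llm.mapping import Mapping

def mixer_mapping(parent: Mapping) -> Mapping:
    """Collapse TP to 1 for the mixer when attention DP is enabled."""
    if parent.enable_attention_dp:
        return Mapping(
            world_size=parent.pp_size,   # only PP contributes to this sub-world
            tp_size=1,
            pp_size=parent.pp_size,
            rank=parent.rank,
            gpus_per_node=parent.gpus_per_node,
            enable_attention_dp=True,
        )
    return parent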

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def __init__(
         self.mamba_ssm_cache_dtype = ssm_cache_dtype

         # get tp size
-        tp_size = mapping.tp_size
+        tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1

         # derive mamba parameters for conv and ssm states
         d_inner = head_dim * num_heads
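
The cache manager sizes conv and SSM state buffers per rank, so attention DP has to be treated as TP of 1 to keep the full state width on every rank. An illustrative calculation (the manager's exact cache-shape formula may differ):

def per_rank_d_inner(head_dim: int, num_heads: int, tp_size: int,
                     enable_attention_dp: bool) -> int:
    # With attention DP, the state is not TP-sharded, so the divisor is 1.
    effective_tp = 1 if enable_attention_dp else tp_size
    d_inner = head_dim * num_heads
    return d_inner // effective_tp

print(per_rank_d_inner(64, 128, tp_size=8, enable_attention_dp=False))  # 1024
print(per_rank_d_inner(64, 128, tp_size=8, enable_attention_dp=True))   # 8192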

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 14 additions & 5 deletions

@@ -5276,6 +5276,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
         (4, 4, False, True, True),
         (4, 1, False, False, True),
         (4, 4, True, False, True),
+        (4, 4, True, True, True),
         (4, 1, True, True, True),
         (4, 4, False, True, False),
         (4, 1, False, False, False),
@@ -5285,9 +5286,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
     )
     def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
                               overlap_scheduler, cuda_graph):
-        if attention_dp:
-            pytest.skip(
-                "Attention DP is not supported for Nemotron-3-Super yet")

         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         mamba_ssm_cache_dtype="float32")
@@ -5314,7 +5312,18 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,

     @skip_pre_blackwell
     @pytest.mark.skip_less_mpi_world_size(8)
-    def test_nvfp4_8gpus(self):
+    @pytest.mark.parametrize(
+        "attention_dp",
+        [
+            False,
+            True,
+        ],
+        ids=[
+            "attention_dp_off",
+            "attention_dp_on",
+        ],
+    )
+    def test_nvfp4_8gpus(self, attention_dp):
         # Use this test to track the best performance config.
         # The optimized config is still under investigation.
         # Adding this test as placeholder.
@@ -5329,7 +5338,7 @@ def test_nvfp4_8gpus(self):
             tensor_parallel_size=8,
             moe_expert_parallel_size=8,
             pipeline_parallel_size=1,
-            enable_attention_dp=False,
+            enable_attention_dp=attention_dp,
             cuda_graph_config=CudaGraphConfig(max_batch_size=32,
                                               enable_padding=True),
             disable_overlap_scheduler=False,
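
For reference, the parametrized flag flows straight into the LLM constructor. A condensed, hypothetical sketch of the test pattern using only arguments visible in this diff (the real checkpoint path, accuracy tasks, and remaining settings are elided):

import pytest
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

@pytest.mark.parametrize("attention_dp", [False, True],
                         ids=["attention_dp_off", "attention_dp_on"])
def test_nvfp4_8gpus_sketch(attention_dp):
    llm = LLM(
        "<nvfp4 checkpoint path elided>",   # placeholder, not the real test's path
        tensor_parallel_size=8,
        moe_expert_parallel_size=8,
        pipeline_parallel_size=1,
        enable_attention_dp=attention_dp,
        cuda_graph_config=CudaGraphConfig(max_batch_size=32, enable_padding=True),
        disable_overlap_scheduler=False,
        kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                      mamba_ssm_cache_dtype="float32"),
    )
    # accuracy task execution omitted in this sketch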

tests/integration/test_lists/qa/llm_function_core.txt
Lines changed: 4 additions & 1 deletion

@@ -247,7 +247,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
 accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
-accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on]
+accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_off]
+

 # multimodal accuracy tests
 accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype

tests/integration/test_lists/test-db/l0_dgx_b200.yml
Lines changed: 3 additions & 2 deletions

@@ -23,7 +23,7 @@ l0_dgx_b200:
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -85,7 +85,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
   - accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
-  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on] TIMEOUT (60)
   - condition:
       ranges:
         system_gpu_count:
@@ -163,6 +163,7 @@ l0_dgx_b200:
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
   - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
+  - accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
