
Commit 7fe7dab

[TRTLLM-10060][feat] Enable attention dp for Nemotron Super v3.
Signed-off-by: nv-guomingz <[email protected]>
1 parent 74832a1 commit 7fe7dab

File tree (5 files changed: 91 additions, 36 deletions)

    tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
    tensorrt_llm/_torch/models/modeling_nemotron_h.py
    tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
    tensorrt_llm/_torch/modules/mlp.py
    tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py
20 additions, 14 deletions

@@ -13,6 +13,12 @@ def preprocess_weights(self, weights: dict) -> dict:
         config = self.config.pretrained_config
         tp_size = self.config.mapping.tp_size
         tp_rank = self.config.mapping.tp_rank
+        enable_attention_dp = self.config.mapping.enable_attention_dp
+
+        # For Mamba2 layers, use tp_size=1 when attention DP is enabled
+        mamba_tp_size = 1 if enable_attention_dp else tp_size
+        mamba_tp_rank = 0 if enable_attention_dp else tp_rank
+
         d_inner = config.mamba_head_dim * config.mamba_num_heads
 
         def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
@@ -24,12 +30,12 @@ def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
                 ],
                 dim=0)
             w = []
-            for rank in range(tp_size):
-                in_proj_z_rank = split(in_proj_z, tp_size, rank)
-                in_proj_x_rank = split(in_proj_x, tp_size, rank)
-                in_proj_b_rank = split(in_proj_b, tp_size, rank)
-                in_proj_c_rank = split(in_proj_c, tp_size, rank)
-                in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
+            for rank in range(mamba_tp_size):
+                in_proj_z_rank = split(in_proj_z, mamba_tp_size, rank)
+                in_proj_x_rank = split(in_proj_x, mamba_tp_size, rank)
+                in_proj_b_rank = split(in_proj_b, mamba_tp_size, rank)
+                in_proj_c_rank = split(in_proj_c, mamba_tp_size, rank)
+                in_proj_dt_rank = split(in_proj_dt, mamba_tp_size, rank)
                 y = torch.concat([
                     in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
                     in_proj_c_rank, in_proj_dt_rank
@@ -67,16 +73,16 @@ def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
                 else:
                     new_weights[key] = weights[name]
             elif "A" in key:
-                w = split(weights[name], tp_size, tp_rank)
+                w = split(weights[name], mamba_tp_size, mamba_tp_rank)
                 w = w.to(torch.float32)
                 w = -torch.exp(w)
                 new_weights[key] = w
             elif "D" in key:
-                w = split(weights[name], tp_size, tp_rank)
+                w = split(weights[name], mamba_tp_size, mamba_tp_rank)
                 w = w.to(torch.float32)
                 new_weights[key] = w
             elif "dt_bias" in key:
-                w = split(weights[name], tp_size, tp_rank)
+                w = split(weights[name], mamba_tp_size, mamba_tp_rank)
                 w = w.to(torch.float32)
                 new_weights[key] = w
             elif "mixer.in_proj" in key:
@@ -91,16 +97,16 @@ def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
                     w, [d_inner, n_groups * d_state, n_groups * d_state], dim=0)
 
                 w = []
-                for rank in range(tp_size):
-                    conv_x_rank = split(conv_x, tp_size, rank)
-                    conv_b_rank = split(conv_b, tp_size, rank)
-                    conv_c_rank = split(conv_c, tp_size, rank)
+                for rank in range(mamba_tp_size):
+                    conv_x_rank = split(conv_x, mamba_tp_size, rank)
+                    conv_b_rank = split(conv_b, mamba_tp_size, rank)
+                    conv_c_rank = split(conv_c, mamba_tp_size, rank)
                     y = torch.concat([conv_x_rank, conv_b_rank, conv_c_rank])
                     w.append(y)
                 w = torch.concat(w).contiguous()
                 new_weights[key] = w
             elif "mixer.norm.weight" in key:
-                w = split(weights[name], tp_size, tp_rank)
+                w = split(weights[name], mamba_tp_size, mamba_tp_rank)
                 new_weights[key] = w
             # Remap MoE expert weights.
             elif "mixer.experts." in key:

tensorrt_llm/_torch/models/modeling_nemotron_h.py
36 additions, 13 deletions

@@ -32,7 +32,7 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
-from ..modules.linear import Linear
+from ..modules.linear import Linear, TensorParallelMode
 from ..modules.mamba.mamba2_mixer import Mamba2Mixer
 from ..modules.mlp import MLP
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
         self,
         model_config: ModelConfig[NemotronHConfig],
         layer_idx: int,
+        reduce_output: bool = False,
     ):
         config = model_config.pretrained_config
+
         super().__init__(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@ def __init__(
             layer_idx=layer_idx,
             dtype=config.torch_dtype,
             config=model_config,
+            reduce_output=reduce_output,
         )
 
     def forward(
@@ -154,6 +157,7 @@ def __init__(
         shared_expert_intermediate_size = (
             config.moe_shared_expert_intermediate_size *
             config.n_shared_experts)
+
         self.shared_experts = MLP(
             hidden_size=config.hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -193,11 +197,14 @@
             activation_type=self.activation_type,
         )
 
-        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
-        self.allreduce = AllReduce(
-            mapping=model_config.mapping,
-            strategy=model_config.allreduce_strategy,
-        )
+        if not model_config.mapping.enable_attention_dp:
+            # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+            self.allreduce = AllReduce(
+                mapping=model_config.mapping,
+                strategy=model_config.allreduce_strategy,
+            )
+        else:
+            self.allreduce = None
 
         # Setup latent projection layers.
         if self.use_latent_moe:
@@ -314,7 +321,11 @@ def __init__(
         elif layer_type == "-":
             self.mixer = MLPLayer(model_config, layer_idx)
         elif layer_type == "*":
-            self.mixer = TransformerLayer(model_config, layer_idx)
+            self.mixer = TransformerLayer(
+                model_config,
+                layer_idx,
+                reduce_output=not model_config.mapping.enable_attention_dp
+                and model_config.mapping.tp_size > 1)
         elif layer_type == "E":
             self.mixer = NemotronHMOE(model_config,
                                       layer_idx=layer_idx,
@@ -357,12 +368,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             aux_stream_list[2],
         }
 
-        # calculate embeddings
-        self.embed_tokens = Embedding(
-            config.vocab_size,
-            config.hidden_size,
-            dtype=config.torch_dtype,
-        )
+        if model_config.mapping.enable_attention_dp:
+            # When attention_dp is enabled, we cannot do all_reduce since
+            # the problem size of different ranks are different.
+            # So, we don't do parallelism here.
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+            )
+        else:
+            self.embed_tokens = Embedding(
+                config.vocab_size,
+                config.hidden_size,
+                dtype=config.torch_dtype,
+                mapping=model_config.mapping,
+                tensor_parallel_mode=TensorParallelMode.COLUMN,
+                gather_output=True,
+            )
 
         # create layers
         layers = []
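The pattern throughout this file is: only reduce across ranks when tensor parallelism actually split the computation, and never when attention DP is on, because each rank then carries a different number of tokens. The toy module below illustrates just that decision; the class and field names are invented for the example and are not the TensorRT-LLM API.

import torch
from torch import nn


class ToyParallelBlock(nn.Module):
    """Illustrative only: reduce the output only when TP actually shards the math."""

    def __init__(self, hidden: int, tp_size: int, enable_attention_dp: bool):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)
        # Mirrors reduce_output / self.allreduce above: no cross-rank reduction
        # under attention DP, since ranks hold different-sized batches.
        self.needs_allreduce = tp_size > 1 and not enable_attention_dp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.proj(x)
        if self.needs_allreduce and torch.distributed.is_initialized():
            torch.distributed.all_reduce(y)
        return y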

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
17 additions, 2 deletions

@@ -57,8 +57,23 @@ def __init__(
 
         config = config or ModelConfig()
         self.mapping = config.mapping
-        tp_rank = config.mapping.tp_rank
-        tp_size = config.mapping.tp_size
+
+        if config.mapping.enable_attention_dp:
+            from tensorrt_llm.mapping import Mapping
+            self.mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+            tp_size = 1
+            tp_rank = 0
+        else:
+            self.mapping = config.mapping
+            tp_size = config.mapping.tp_size
+            tp_rank = config.mapping.tp_rank
 
         d_inner = head_dim * nheads
         d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
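A back-of-the-envelope view of what the tp_size=1 override means for this module, using made-up sizes rather than values from the commit. The d_in_proj formula is the one visible in the context lines above; the real code splits each component (z, x, B, C, dt) separately, but when every component divides evenly the per-rank width is simply d_in_proj // tp.

def in_proj_width_per_rank(head_dim: int, nheads: int, n_groups: int,
                           d_state: int, tp_size: int,
                           enable_attention_dp: bool) -> int:
    # Same formula as above; attention DP forces the effective TP to 1,
    # so the whole in_proj lives on every rank.
    d_inner = head_dim * nheads
    d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
    effective_tp = 1 if enable_attention_dp else tp_size
    return d_in_proj // effective_tp


# Made-up example sizes:
print(in_proj_width_per_rank(64, 128, 8, 128, tp_size=4, enable_attention_dp=False))  # 4640
print(in_proj_width_per_rank(64, 128, 8, 128, tp_size=4, enable_attention_dp=True))   # 18560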

tensorrt_llm/_torch/modules/mlp.py
17 additions, 6 deletions

@@ -4,6 +4,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm.mapping import Mapping
+
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
@@ -21,24 +23,33 @@ def __init__(self,
                  config: Optional[ModelConfig] = None,
                  layer_idx: Optional[int] = None,
                  reduce_output: bool = True):
-
+        if config.mapping.enable_attention_dp:
+            mapping = Mapping(
+                world_size=config.mapping.pp_size,
+                tp_size=1,
+                pp_size=config.mapping.pp_size,
+                rank=config.mapping.rank,
+                gpus_per_node=config.mapping.gpus_per_node,
+                enable_attention_dp=True,
+            )
+        else:
+            mapping = config.mapping
         super().__init__()
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.activation = activation
 
         config = config or ModelConfig()
-        self.up_lora = LoraLayer(
-            [LoraModuleType.MLP_H_TO_4H],
-            [self.intermediate_size // config.mapping.tp_size])
+        self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
+                                 [self.intermediate_size // mapping.tp_size])
 
         self.up_proj = Linear(
             self.hidden_size,
             self.intermediate_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.COLUMN,
             weights_loading_config=WeightsLoadingConfig(
                 weight_mode=WeightMode.VANILLA),
@@ -55,7 +66,7 @@ def __init__(self,
             self.hidden_size,
             bias=bias,
             dtype=dtype,
-            mapping=config.mapping,
+            mapping=mapping,
             tensor_parallel_mode=TensorParallelMode.ROW,
             quant_config=config.get_quant_config(),
             skip_create_weights_in_init=config.skip_create_weights_in_init,
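The MLP applies the same override at module granularity: pick an effective mapping once, then let every size that used to read config.mapping.tp_size read the override instead. Below is a plain-Python sketch with a dataclass standing in for tensorrt_llm.mapping.Mapping; the names are illustrative, not the library API.

from dataclasses import dataclass


@dataclass
class ToyMapping:
    tp_size: int = 1
    pp_size: int = 1
    rank: int = 0
    enable_attention_dp: bool = False


def effective_mlp_mapping(mapping: ToyMapping) -> ToyMapping:
    # Under attention DP the MLP builds a fresh mapping with tp_size=1,
    # so its column/row-parallel linears are not sharded.
    if mapping.enable_attention_dp:
        return ToyMapping(tp_size=1, pp_size=mapping.pp_size, rank=mapping.rank,
                          enable_attention_dp=True)
    return mapping


def lora_up_out_features(intermediate_size: int, mapping: ToyMapping) -> int:
    # The LoRA H->4H slice must follow the same effective tp_size as up_proj.
    return intermediate_size // effective_mlp_mapping(mapping).tp_size


print(lora_up_out_features(14336, ToyMapping(tp_size=4)))                            # 3584
print(lora_up_out_features(14336, ToyMapping(tp_size=4, enable_attention_dp=True)))  # 14336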

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
1 addition, 1 deletion

@@ -45,7 +45,7 @@ def __init__(
         self.mamba_ssm_cache_dtype = ssm_cache_dtype
 
         # get tp size
-        tp_size = mapping.tp_size
+        tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1
 
         # derive mamba parameters for conv and ssm states
         d_inner = head_dim * num_heads
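The one-line change above propagates the same rule into cache sizing. A minimal, illustrative sketch (the helper name is hypothetical, not the cache manager's code): under attention DP each rank caches state for the full set of Mamba heads rather than a 1/tp shard.

def cached_mamba_heads_per_rank(num_heads: int, tp_size: int,
                                enable_attention_dp: bool) -> int:
    # Same fallback as in the diff: size per-rank state with tp_size=1
    # when attention DP is enabled.
    tp = tp_size if not enable_attention_dp else 1
    return num_heads // tp


assert cached_mamba_heads_per_rank(128, tp_size=8, enable_attention_dp=False) == 16
assert cached_mamba_heads_per_rank(128, tp_size=8, enable_attention_dp=True) == 128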
