
Commit 14554ab

[None][feat] Support multi-gpu running for nemotron-v3-nano and super (#10118)
Signed-off-by: Wanli Jiang <[email protected]>
1 parent 819d03f commit 14554ab

File tree: 6 files changed, +122 −40 lines

tensorrt_llm/_torch/models/checkpoints/hf/nemotron_h_weight_mapper.py

Lines changed: 41 additions & 32 deletions
@@ -15,6 +15,30 @@ def preprocess_weights(self, weights: dict) -> dict:
         tp_rank = self.config.mapping.tp_rank
         d_inner = config.mamba_head_dim * config.mamba_num_heads
 
+        def _split_mamba2_mixer_in_proj(w: torch.Tensor) -> torch.Tensor:
+            # Special handling for Mamba2 mixer in_proj.weights and scales.
+            in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
+                w, [
+                    d_inner, d_inner, n_groups * d_state, n_groups * d_state,
+                    nheads
+                ],
+                dim=0)
+            w = []
+            for rank in range(tp_size):
+                in_proj_z_rank = split(in_proj_z, tp_size, rank)
+                in_proj_x_rank = split(in_proj_x, tp_size, rank)
+                in_proj_b_rank = split(in_proj_b, tp_size, rank)
+                in_proj_c_rank = split(in_proj_c, tp_size, rank)
+                in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
+                y = torch.concat([
+                    in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
+                    in_proj_c_rank, in_proj_dt_rank
+                ])
+                w.append(y)
+            w = torch.concat(w).contiguous()
+            return w
+
+        is_nvfp4 = self.config.quant_config.quant_algo == "NVFP4"
         n_groups = config.n_groups
         d_state = config.ssm_state_size
         nheads = config.mamba_num_heads
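The helper above re-packs the fused Mamba2 in_proj weight so that each tensor-parallel rank's contiguous column shard still contains its own [z, x, B, C, dt] blocks in order. A minimal standalone sketch of that reordering, assuming a simple even-chunk stand-in for TensorRT-LLM's `split` utility and toy block sizes:

import torch

def split(t: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
    # Simplified even split along dim 0 (stand-in for TensorRT-LLM's TP split helper).
    return torch.chunk(t, tp_size, dim=0)[rank]

d_inner, n_groups, d_state, nheads, tp_size = 8, 2, 4, 4, 2
sizes = [d_inner, d_inner, n_groups * d_state, n_groups * d_state, nheads]
w = torch.arange(sum(sizes), dtype=torch.float32).unsqueeze(1)  # fused [z, x, B, C, dt]

z, x, b, c, dt = torch.split(w, sizes, dim=0)
reordered = torch.concat([
    torch.concat([split(blk, tp_size, rank) for blk in (z, x, b, c, dt)])
    for rank in range(tp_size)
]).contiguous()

# Rank 0 can now take a plain contiguous slice and still see its own
# z/x/B/C/dt sub-blocks in order.
rank0 = reordered[:reordered.shape[0] // tp_size]
print(rank0.shape)  # torch.Size([18, 1])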
@@ -36,7 +60,12 @@ def preprocess_weights(self, weights: dict) -> dict:
 
             if ("mixer.in_proj" in key
                     or "mixer.out_proj" in key) and "_scale" in key:
-                new_weights[key] = weights[name]
+                # Special handling for nvfp4 Mamba2 mixer in_proj.weight_scale.
+                if is_nvfp4 and "in_proj.weight_scale_2" not in key and "in_proj.weight_scale" in key:
+                    new_weights[key] = _split_mamba2_mixer_in_proj(
+                        weights[name])
+                else:
+                    new_weights[key] = weights[name]
             elif "A" in key:
                 w = split(weights[name], tp_size, tp_rank)
                 w = w.to(torch.float32)
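For NVFP4 checkpoints the per-block `in_proj.weight_scale` tensor follows the same output-dimension layout as the weight itself, so it goes through the same per-rank reordering, while the scalar `in_proj.weight_scale_2` and the out_proj scales are copied unchanged. An illustrative check of that key routing (the key strings below are hypothetical examples in the checkpoint's naming style):

is_nvfp4 = True
for key in (
        "layers.0.mixer.in_proj.weight_scale",     # per-block scales -> reorder per TP rank
        "layers.0.mixer.in_proj.weight_scale_2",   # scalar global scale -> copy as-is
        "layers.0.mixer.out_proj.weight_scale",    # out_proj scales -> copy as-is
):
    needs_split = (is_nvfp4 and "in_proj.weight_scale_2" not in key
                   and "in_proj.weight_scale" in key)
    print(key, "->", "split" if needs_split else "copy")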
@@ -51,29 +80,7 @@ def preprocess_weights(self, weights: dict) -> dict:
                 w = w.to(torch.float32)
                 new_weights[key] = w
             elif "mixer.in_proj" in key:
-                w = weights[name]
-                in_proj_z, in_proj_x, in_proj_b, in_proj_c, in_proj_dt = torch.split(
-                    w, [
-                        d_inner, d_inner, n_groups * d_state,
-                        n_groups * d_state, nheads
-                    ],
-                    dim=0)
-
-                w = []
-                for rank in range(tp_size):
-                    in_proj_z_rank = split(in_proj_z, tp_size, rank)
-                    in_proj_x_rank = split(in_proj_x, tp_size, rank)
-                    in_proj_b_rank = split(in_proj_b, tp_size, rank)
-                    in_proj_c_rank = split(in_proj_c, tp_size, rank)
-                    in_proj_dt_rank = split(in_proj_dt, tp_size, rank)
-                    y = torch.concat([
-                        in_proj_z_rank, in_proj_x_rank, in_proj_b_rank,
-                        in_proj_c_rank, in_proj_dt_rank
-                    ])
-                    w.append(y)
-
-                w = torch.concat(w).contiguous()
-                new_weights[key] = w
+                new_weights[key] = _split_mamba2_mixer_in_proj(weights[name])
             elif "conv1d" in key:
                 w = weights[name]
                 # removing dim(1) because we are using Linear to store conv1d weights
@@ -110,19 +117,21 @@ def preprocess_weights(self, weights: dict) -> dict:
                 elif "weight_scale" in key:
                     # NVFP4 case.
                     if weights[name].shape:
-                        new_weights[w3_key] = weights[
-                            name][:weights[name].shape[0] // 2]
-                        new_weights[w1_key] = weights[name][
-                            weights[name].shape[0] // 2:]
+                        # w3 weight (gate_proj) scale should be empty for Nemotron-H MoE model.
+                        # Use [:0] to keep the same input dimension as the other weights.
+                        # The w3 weight_scale shape should be [0, input_dim].
+                        new_weights[w3_key] = weights[name][:0]
+                        new_weights[w1_key] = weights[name]
                     # FP8 case.
                     else:
                         new_weights[w3_key] = weights[name]
                         new_weights[w1_key] = weights[name]
                 else:
-                    new_weights[w3_key] = weights[name][:weights[name].
-                                                        shape[0] // 2]
-                    new_weights[w1_key] = weights[name][weights[name].
-                                                        shape[0] // 2:]
+                    # w3 weight (gate_proj) should be empty for Nemotron-H MoE model.
+                    # Use [:0] to keep the same input dimension as the other weights.
+                    # The w3 weight shape should be [0, input_dim].
+                    new_weights[w3_key] = weights[name][:0]
+                    new_weights[w1_key] = weights[name]
             elif "down_proj" in key:
                 key = key.replace("down_proj", "w2")
                 new_weights[key] = weights[name]
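Nemotron-H MoE experts use a squared-ReLU activation with no gating branch, so only w1 (up_proj) carries real data; slicing with `[:0]` produces a zero-row w3 that preserves the input dimension and keeps downstream fused-buffer bookkeeping consistent. A small sketch of that shape behavior with toy sizes:

import torch

up_proj = torch.randn(128, 64)      # toy w1 (up_proj) weight, [out_dim, in_dim]
w3 = up_proj[:0]                    # empty gate weight: shape [0, 64]
w1 = up_proj                        # full up_proj weight: shape [128, 64]

fused = torch.cat([w3, w1], dim=0)  # fused w3_w1 buffer holds only w1's rows
print(w3.shape, w1.shape, fused.shape)
# torch.Size([0, 64]) torch.Size([128, 64]) torch.Size([128, 64])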

tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py

Lines changed: 6 additions & 1 deletion
@@ -69,12 +69,17 @@ def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
             num_kv_heads = kv_shape * 2 // self._head_dim
         else:
             num_kv_heads = kv_shape // self._head_dim
+
+        duplicated_keys = ["weight", "bias"]
+        if module.quant_config.quant_mode.has_nvfp4():
+            duplicated_keys.append("weight_scale")
+
         processed_weights = {
             k:
             self._duplicate_kv(weight=v[:],
                                num_kv_heads=num_kv_heads,
                                tensor_parallel_size=self._tp_size)
-            if k in ["weight", "bias"] else v
+            if k in duplicated_keys else v
             for k, v in weights.items()
         }
         return processed_weights
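When KV heads are duplicated across tensor-parallel ranks, the NVFP4 per-block `weight_scale` shares the weight's output-channel layout, so it must be duplicated the same way as `weight` and `bias`; extending `duplicated_keys` routes it through `_duplicate_kv`. A minimal sketch of the selection pattern, with a toy duplication function standing in for `_duplicate_kv`:

import torch

def duplicate(t: torch.Tensor, repeats: int = 2) -> torch.Tensor:
    # Toy stand-in for _duplicate_kv: repeat rows along the output dimension.
    return t.repeat_interleave(repeats, dim=0)

weights = {
    "weight": torch.randn(4, 8),
    "weight_scale": torch.randn(4, 2),    # per-block scales share the weight's row layout
    "weight_scale_2": torch.tensor(1.0),  # scalar global scale, never duplicated
}

duplicated_keys = ["weight", "bias"]
has_nvfp4 = True  # assumed result of module.quant_config.quant_mode.has_nvfp4()
if has_nvfp4:
    duplicated_keys.append("weight_scale")

processed = {k: duplicate(v) if k in duplicated_keys else v for k, v in weights.items()}
print({k: tuple(v.shape) for k, v in processed.items()})
# {'weight': (8, 8), 'weight_scale': (8, 2), 'weight_scale_2': ()}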

tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 15 additions & 2 deletions
@@ -26,6 +26,7 @@
 from tensorrt_llm._torch.utils import ActivationType, relu2
 
 from ..attention_backend import AttentionMetadata
+from ..distributed import AllReduce
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
 from ..modules.decoder_layer import DecoderLayer
@@ -124,7 +125,7 @@ def __init__(
         from .modeling_deepseekv3 import DeepseekV3Gate
 
         self.activation_type = ActivationType.Relu2
-        self.reduce_results = True
+        self.reduce_results = False
 
         config = model_config.pretrained_config
         self.hidden_dim = config.hidden_size
@@ -144,6 +145,7 @@ def __init__(
         self.top_k = config.num_experts_per_tok
         self.enable_attention_dp = model_config.mapping.enable_attention_dp
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.mapping = model_config.mapping
 
         # Setup shared expert MLP.
         if config.n_shared_experts is None or config.n_shared_experts == 0:
@@ -160,6 +162,7 @@ def __init__(
                 dtype=config.torch_dtype,
                 config=model_config,
                 layer_idx=self.layer_idx,
+                reduce_output=False,
             )
         # Setup MoE gate.
         self.gate = DeepseekV3Gate(
@@ -190,6 +193,12 @@ def __init__(
             activation_type=self.activation_type,
         )
 
+        # AllReduce for combining shared and routed expert outputs in multi-GPU settings.
+        self.allreduce = AllReduce(
+            mapping=model_config.mapping,
+            strategy=model_config.allreduce_strategy,
+        )
+
         # Setup latent projection layers.
         if self.use_latent_moe:
             self.fc1_latent_proj = Linear(
@@ -223,6 +232,7 @@ def forward(
         assert hidden_states.shape[-1] == self.hidden_dim
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_dim)
+        all_rank_num_tokens = attn_metadata.all_rank_num_tokens
 
         def _compute_shared_output():
             if self.shared_experts is not None:
@@ -239,7 +249,6 @@ def _compute_routed_output():
                 routed_hidden_states = self.fc1_latent_proj(
                     routed_hidden_states)
 
-            all_rank_num_tokens = attn_metadata.all_rank_num_tokens
             final_hidden_states = self.experts(
                 routed_hidden_states,
                 router_logits,
@@ -258,6 +267,10 @@ def _compute_routed_output():
 
         final_hidden_states = shared_output + routed_output
 
+        # Perform all-reduce after combining outputs for multi-GPU support.
+        if not self.enable_attention_dp and self.mapping.tp_size > 1:
+            final_hidden_states = self.allreduce(final_hidden_states)
+
         return final_hidden_states.view(orig_shape)
 
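With `reduce_results = False` on the routed experts and `reduce_output=False` on the shared-expert MLP, each rank now produces partial sums for both paths, and a single AllReduce runs once on their sum instead of once per path. A minimal sketch of that pattern, using plain `torch.distributed` as a stand-in for TensorRT-LLM's `AllReduce` module (assumes a process group is already initialized; the inputs are this rank's partial outputs):

import torch
import torch.distributed as dist

def combine_moe_outputs(shared_partial: torch.Tensor,
                        routed_partial: torch.Tensor,
                        tp_size: int,
                        enable_attention_dp: bool) -> torch.Tensor:
    # Sum the two partial contributions first, then reduce across TP ranks once.
    out = shared_partial + routed_partial
    if not enable_attention_dp and tp_size > 1:
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out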

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 55 additions & 3 deletions
@@ -475,12 +475,21 @@ def load_expert_w3_w1_weight(self,
                                      TensorParallelMode.COLUMN,
                                      device=device) if w3_weight is not None else None
 
-        dst_w3_weight, dst_w1_weight = dst_w3_w1_weight.chunk(2, dim=0)
+        src_w3_size_shard = w3_weight_shard.shape[
+            0] if w3_weight_shard is not None else 0
+        src_w1_size_shard = w1_weight_shard.shape[
+            0] if w1_weight_shard is not None else 0
         if w1_weight is not None:
+            dst_w1_weight = dst_w3_w1_weight.narrow(dim=0,
+                                                    start=src_w3_size_shard,
+                                                    length=src_w1_size_shard)
             dst_w1_weight.copy_(w1_weight_shard.contiguous().view(
                 dst_w3_w1_weight.dtype),
                                 non_blocking=True)
         if w3_weight is not None:
+            dst_w3_weight = dst_w3_w1_weight.narrow(dim=0,
+                                                    start=0,
+                                                    length=src_w3_size_shard)
             dst_w3_weight.copy_(w3_weight_shard.contiguous().view(
                 dst_w3_w1_weight.dtype),
                                 non_blocking=True)
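Because w3 may now be empty for Nemotron-H, the fused destination buffer can no longer be split into two equal halves with `chunk(2, dim=0)`; the loader instead narrows to offsets computed from the actual shard sizes. A toy illustration of the difference (made-up sizes):

import torch

w3_shard = torch.randn(0, 16)    # empty gate (w3) shard
w1_shard = torch.randn(32, 16)   # up_proj (w1) shard
dst = torch.empty(w3_shard.shape[0] + w1_shard.shape[0], 16)

src_w3_size = w3_shard.shape[0]
src_w1_size = w1_shard.shape[0]

# chunk(2, dim=0) would assume two equal halves and mis-place w1 when w3 is empty.
dst.narrow(0, 0, src_w3_size).copy_(w3_shard)
dst.narrow(0, src_w3_size, src_w1_size).copy_(w1_shard)
print(dst.shape)  # torch.Size([32, 16])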
@@ -701,6 +710,37 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         module.fc2_dequant.data.copy_(tmp_w2_weight_scale * max_fc2_input_scale)
         module.fc31_input_dequant.data.copy_(max_fc31_input_scale)
 
+    def post_load_weights(self, module):
+        super().post_load_weights(module)
+
+        # Padding weights to meet FP8 GEMM alignment requirements.
+        def _maybe_padding_weights(tensor: torch.Tensor, row_alignment: int,
+                                   col_alignment: int):
+            row_pad_size = (row_alignment - tensor.size(1)) % row_alignment
+            col_pad_size = (col_alignment - tensor.size(2)) % col_alignment
+            is_padded = row_pad_size != 0 or col_pad_size != 0
+            if is_padded:
+                return F.pad(tensor, (0, col_pad_size, 0, row_pad_size),
+                             mode='constant',
+                             value=0), is_padded
+            return tensor, is_padded
+
+        if getattr(module, "moe_backend", None) == "CUTLASS":
+            cutlass_fp8_row_alignment, cutlass_fp8_col_alignment = 32, 16
+            padded_w3_w1_weight, is_padded_w3_w1_weight = _maybe_padding_weights(
+                module.w3_w1_weight, cutlass_fp8_row_alignment,
+                cutlass_fp8_col_alignment)
+            # Use `row_alignment` for `w2_weight.shape[2]` to match the shape of `w3_w1_weight.shape[1]`.
+            padded_w2_weight, is_padded_w2_weight = _maybe_padding_weights(
+                module.w2_weight, cutlass_fp8_row_alignment,
+                cutlass_fp8_row_alignment)
+            if is_padded_w3_w1_weight:
+                module.w3_w1_weight = nn.Parameter(padded_w3_w1_weight,
+                                                   requires_grad=False)
+            if is_padded_w2_weight:
+                module.w2_weight = nn.Parameter(padded_w2_weight,
+                                                requires_grad=False)
+
 
 class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
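The `(alignment - size) % alignment` idiom in `_maybe_padding_weights` computes how many zero rows and columns are needed to round each trailing dimension up to the next multiple of the alignment, and `F.pad` pads the last dimension first. A worked example with toy shapes (the (experts, rows, cols) sizes below are illustrative, not the real model dimensions):

import torch
import torch.nn.functional as F

t = torch.zeros(4, 50, 30)        # (num_experts, rows, cols), toy sizes
row_alignment, col_alignment = 32, 16

row_pad = (row_alignment - t.size(1)) % row_alignment  # (32 - 50) % 32 = 14 -> rows padded to 64
col_pad = (col_alignment - t.size(2)) % col_alignment  # (16 - 30) % 16 = 2  -> cols padded to 32

# F.pad pairs apply to the last dims first: (left, right) for dim 2, then (top, bottom) for dim 1.
padded = F.pad(t, (0, col_pad, 0, row_pad), mode='constant', value=0)
print(padded.shape)  # torch.Size([4, 64, 32])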

@@ -2079,10 +2119,12 @@ def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
 
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
-        block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8
+        self.block_scales_vec_size = torch.iinfo(
+            self.block_scales_dtype).bits // 8
 
         super().create_weights(module, self.weight_dtype, weight_vec_size,
-                               self.block_scales_dtype, block_scales_vec_size)
+                               self.block_scales_dtype,
+                               self.block_scales_vec_size)
 
     def load_expert_w3_w1_weight_scale_nvfp4(
         self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
@@ -2131,6 +2173,16 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                              module.tp_rank,
                                              TensorParallelMode.ROW,
                                              device=device)
+        # Padding w2_weight_scale (dtype=float8_e4m3fn) to match the shape of dst_w2_weight_scale (dtype=float32).
+        src_w2_scale_size = w2_weight_scale.shape[1]
+        adjusted_dst_w2_scale_size = dst_w2_weight_scale.shape[
+            1] * self.block_scales_vec_size
+        assert adjusted_dst_w2_scale_size >= src_w2_scale_size, "adjusted_dst_w2_scale_size must be greater than or equal to src_w2_scale_size"
+        if adjusted_dst_w2_scale_size > src_w2_scale_size:
+            w2_weight_scale = torch.nn.functional.pad(
+                w2_weight_scale,
+                (0, adjusted_dst_w2_scale_size - src_w2_scale_size), "constant",
+                0).contiguous()
 
         cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
         cast_w2_weight_scale = self._maybe_padding_shape(
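The destination w2 scale buffer is stored in a wider packed dtype, so each destination element holds `block_scales_vec_size` of the 8-bit source scales; the source therefore has to be zero-padded to a multiple of that factor before the dtype view. A sketch of the arithmetic with toy sizes (uint8 stands in for float8_e4m3fn, and block_scales_vec_size = 4 assumes 32-bit packing):

import torch
import torch.nn.functional as F

block_scales_vec_size = 4                            # 32-bit packed dtype / 8-bit scales (assumed)
w2_scale = torch.zeros(8, 45, dtype=torch.uint8)     # stand-in for float8_e4m3fn scales
dst_cols = 12                                        # destination columns in the packed dtype

adjusted_dst_size = dst_cols * block_scales_vec_size   # 48 unpacked scale slots
pad = adjusted_dst_size - w2_scale.shape[1]             # 3 trailing zero scales
if pad > 0:
    w2_scale = F.pad(w2_scale, (0, pad), "constant", 0).contiguous()

packed = w2_scale.view(torch.int32)  # only possible once the last dim is a multiple of 4
print(w2_scale.shape, packed.shape)  # torch.Size([8, 48]) torch.Size([8, 12])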

tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_metadata: Mamba2Metadata,
+        **kwargs,
     ) -> torch.Tensor:
 
         # calculate split size

tensorrt_llm/_torch/modules/mlp.py

Lines changed: 4 additions & 2 deletions
@@ -19,7 +19,8 @@ def __init__(self,
                  activation: Callable[[torch.Tensor], torch.Tensor] = None,
                  dtype: Optional[torch.dtype] = None,
                  config: Optional[ModelConfig] = None,
-                 layer_idx: Optional[int] = None):
+                 layer_idx: Optional[int] = None,
+                 reduce_output: bool = True):
 
         super().__init__()
         self.layer_idx = layer_idx
@@ -60,7 +61,8 @@ def __init__(self,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            reduce_output=reduce_output)
 
     def forward(
         self,
