
Commit 5bf20d2

adapt nemotron modeling to support super v3
Fix MLP dims to support latent dimension

Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 14554ab commit 5bf20d2

File tree

1 file changed: +28 -4 lines changed


tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py

Lines changed: 28 additions & 4 deletions
@@ -250,14 +250,23 @@ def forward(self, hidden_states):
 
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: int,
+        intermediate_size: Optional[int] = None,
+        is_expert: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = intermediate_size or config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # Use latent size for expert MLPs if provided by config (required for SuperV3)
+        use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert
+        input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size
+        self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
 
     def forward(self, x):
@@ -271,7 +280,10 @@ def __init__(self, config, layer_idx: Optional[int] = None):
         self.experts = nn.ModuleList(
             [
                 NemotronHMLP(
-                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                    config,
+                    layer_idx=layer_idx,
+                    intermediate_size=config.moe_intermediate_size,
+                    is_expert=True,
                 )
                 for _ in range(config.n_routed_experts)
             ]
@@ -281,7 +293,19 @@ def __init__(self, config, layer_idx: Optional[int] = None):
             config=config,
             intermediate_size=config.moe_shared_expert_intermediate_size,
             layer_idx=layer_idx,
+            is_expert=False,
        )
+        # Add latent projections when using latent MoE (required for SuperV3)
+        if getattr(config, "moe_latent_size", None) is not None:
+            self.fc1_latent_proj = nn.Linear(
+                config.hidden_size, config.moe_latent_size, bias=config.mlp_bias
+            )
+            self.fc2_latent_proj = nn.Linear(
+                config.moe_latent_size, config.hidden_size, bias=config.mlp_bias
+            )
+        else:
+            self.fc1_latent_proj = nn.Identity()
+            self.fc2_latent_proj = nn.Identity()
 
     def forward(self, hidden_states: torch.Tensor):
         residuals = hidden_states
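
The hunks above only touch the constructors: NemotronHMLP can now size its projections to config.moe_latent_size when built as a routed expert (is_expert=True), and the sparse MoE block registers fc1_latent_proj / fc2_latent_proj, falling back to nn.Identity when no latent size is configured. The forward path is not part of this diff, so the following is only a minimal sketch of how such latent projections could be wired around routed experts: hidden states are projected down to the latent dimension before dispatch and projected back afterwards. The module name LatentMoESketch and the top-1 routing are illustrative assumptions, not the repository's implementation.

# Hypothetical sketch (not from this commit): latent projections wrapped
# around routed experts that operate entirely in the latent dimension.
import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    def __init__(self, hidden_size: int, latent_size: int, intermediate_size: int, n_experts: int):
        super().__init__()
        # Project hidden states down to the latent dimension before routing ...
        self.fc1_latent_proj = nn.Linear(hidden_size, latent_size)
        # ... and back up to the model's hidden size afterwards.
        self.fc2_latent_proj = nn.Linear(latent_size, hidden_size)
        self.router = nn.Linear(hidden_size, n_experts)
        # Each expert MLP works in the latent dimension, mirroring the
        # input_size change made to NemotronHMLP in this commit.
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(latent_size, intermediate_size),
                    nn.ReLU(),
                    nn.Linear(intermediate_size, latent_size),
                )
                for _ in range(n_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Top-1 routing for brevity; the real NemotronH routing differs.
        weights = torch.softmax(self.router(hidden_states), dim=-1)
        top_w, top_idx = weights.max(dim=-1, keepdim=True)
        latent = self.fc1_latent_proj(hidden_states)
        out = torch.zeros_like(latent)
        for i, expert in enumerate(self.experts):
            mask = top_idx.squeeze(-1) == i
            if mask.any():
                out[mask] = expert(latent[mask])
        return self.fc2_latent_proj(out * top_w)

When moe_latent_size is unset, the diff's nn.Identity fallbacks make fc1_latent_proj and fc2_latent_proj no-ops, so the block degenerates to the original behavior with expert MLPs sized to hidden_size.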
