@@ -250,14 +250,23 @@ def forward(self, hidden_states):
 
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: int,
+        intermediate_size: Optional[int] = None,
+        is_expert: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = intermediate_size or config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # Use latent size for expert MLPs if provided by config (required for SuperV3)
+        use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert
+        input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size
+        self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
 
     def forward(self, x):
@@ -271,7 +280,10 @@ def __init__(self, config, layer_idx: Optional[int] = None):
         self.experts = nn.ModuleList(
             [
                 NemotronHMLP(
-                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                    config,
+                    layer_idx=layer_idx,
+                    intermediate_size=config.moe_intermediate_size,
+                    is_expert=True,
                 )
                 for _ in range(config.n_routed_experts)
             ]
@@ -281,7 +293,19 @@ def __init__(self, config, layer_idx: Optional[int] = None):
             config=config,
             intermediate_size=config.moe_shared_expert_intermediate_size,
             layer_idx=layer_idx,
+            is_expert=False,
         )
+        # Add latent projections when using latent MoE (required for SuperV3)
+        if getattr(config, "moe_latent_size", None) is not None:
+            self.fc1_latent_proj = nn.Linear(
+                config.hidden_size, config.moe_latent_size, bias=config.mlp_bias
+            )
+            self.fc2_latent_proj = nn.Linear(
+                config.moe_latent_size, config.hidden_size, bias=config.mlp_bias
+            )
+        else:
+            self.fc1_latent_proj = nn.Identity()
+            self.fc2_latent_proj = nn.Identity()
 
     def forward(self, hidden_states: torch.Tensor):
         residuals = hidden_states
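
The hunk above ends before the MoE forward pass, so the diff does not show how fc1_latent_proj and fc2_latent_proj are applied. For orientation only, here is a minimal self-contained sketch assuming the usual latent-MoE pattern: project hidden states down to moe_latent_size before the routed experts, run the latent-sized expert MLPs, project the routed output back to hidden_size, and keep the shared expert at full hidden size. All names and sizes below are illustrative, not taken from this commit.

# Illustrative sketch only -- not the NemotronH MoE forward from this commit.
import torch
import torch.nn as nn


class LatentMoESketch(nn.Module):
    def __init__(self, hidden_size=16, moe_latent_size=8, intermediate_size=32, n_routed_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, n_routed_experts, bias=False)
        # Routed experts operate in the latent space (mirrors is_expert=True above).
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(moe_latent_size, intermediate_size),
                nn.ReLU(),
                nn.Linear(intermediate_size, moe_latent_size),
            )
            for _ in range(n_routed_experts)
        )
        # Shared expert keeps the full hidden size (is_expert=False above).
        self.shared_experts = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.ReLU(),
            nn.Linear(intermediate_size, hidden_size),
        )
        # Down/up projections around the routed experts (fc1/fc2_latent_proj above).
        self.fc1_latent_proj = nn.Linear(hidden_size, moe_latent_size)
        self.fc2_latent_proj = nn.Linear(moe_latent_size, hidden_size)

    def forward(self, hidden_states):  # (batch, seq, hidden_size)
        scores = self.router(hidden_states).softmax(dim=-1)
        weights, selected = scores.topk(self.top_k, dim=-1)
        latent = self.fc1_latent_proj(hidden_states)
        routed = torch.zeros_like(latent)
        # Dense for simplicity: every expert runs on all tokens and is masked
        # by its routing weight (a real implementation gathers per-expert tokens).
        for idx, expert in enumerate(self.experts):
            token_weight = (weights * (selected == idx)).sum(dim=-1, keepdim=True)
            routed = routed + token_weight * expert(latent)
        return self.shared_experts(hidden_states) + self.fc2_latent_proj(routed)


if __name__ == "__main__":
    x = torch.randn(2, 5, 16)
    print(LatentMoESketch()(x).shape)  # torch.Size([2, 5, 16])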