@@ -147,11 +147,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states = final_hidden_states + shared_output \
+                * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
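Note on this hunk: the routed-expert output is no longer multiplied up by `self.routed_scaling_factor`; the shared-expert branch is divided down instead, so the MoE block now returns everything at `1 / routed_scaling_factor` of its nominal scale. A plausible motivation is fp16 range: scaling the already-large routed output up can exceed fp16's largest finite value (about 65504), while scaling the smaller shared branch down cannot. A minimal sketch of that failure mode, with made-up magnitudes (the factor 16.0 and the value 4096.0 are illustrative, not taken from the model config):

```python
import torch

s = 16.0  # stand-in for routed_scaling_factor
routed = torch.full((4,), 4096.0, dtype=torch.float16)  # routed-expert output
shared = torch.ones(4, dtype=torch.float16)             # shared-expert output

# Old formulation: scale the routed branch up before adding.
# 4096 * 16 = 65536 exceeds fp16's largest finite value (~65504).
old = routed * s + shared
print(old)  # tensor([inf, inf, inf, inf], dtype=torch.float16)

# New formulation: leave the routed branch alone and scale the shared
# branch down; the sum is the old result divided by s, but stays finite.
new = routed + shared * (1.0 / s)
print(new)  # tensor([4096., 4096., 4096., 4096.], dtype=torch.float16)
```

The deferred factor is accounted for in the decoder layer, as the hunks below show.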
@@ -375,6 +375,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
     def forward(
         self,
@@ -399,9 +400,14 @@ def forward(
         )
 
         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE):
+            hidden_states *= 1. / self.mlp.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP):
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
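Net effect of the decoder-layer hunks, as far as the diff shows: the residual stream is carried at `1 / routed_scaling_factor` of its nominal scale. The attention output is pre-divided before the post-attention norm so it matches the scale a `DeepseekV2MoE` block now produces (first hunk), and dense `DeepseekV2MLP` layers divide both `hidden_states` and `residual` so they stay in the same regime. This rescaling is harmless because RMSNorm output is invariant to a uniform rescaling of its input, so every normalized input, including the final norm before the LM head, is unchanged. A minimal sketch of that invariance, using a reference RMSNorm written out here for illustration (not the `RMSNorm` class imported in this file):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: normalize by the root-mean-square over the last dim.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

x = torch.randn(2, 8)
w = torch.ones(8)
s = 16.0  # stand-in for routed_scaling_factor

# Rescaling the input rescales its RMS by the same factor, so the
# normalized output is (up to eps) unchanged.
assert torch.allclose(rms_norm(x, w), rms_norm(x / s, w), atol=1e-3)
```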