@@ -150,11 +150,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states=hidden_states,
-                                           router_logits=router_logits)
+        if hidden_states.dtype != torch.float16:
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits) * self.routed_scaling_factor
+        else:
+            final_hidden_states = self.experts(hidden_states=hidden_states,
+                                               router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output \
-                * (1. / self.routed_scaling_factor)
+            if hidden_states.dtype != torch.float16:
+                final_hidden_states = final_hidden_states + shared_output
+            else:
+                final_hidden_states = final_hidden_states + shared_output \
+                    * (1. / self.routed_scaling_factor)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
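The FP16 special-casing exists because multiplying the routed expert output by routed_scaling_factor (16.0 in the published DeepSeek-V2 config) can overflow float16, whose largest finite value is 65504. A minimal sketch of the failure mode and the workaround, with illustrative magnitudes that are not taken from the model:

```python
import torch

# Illustrative magnitudes; 16.0 matches routed_scaling_factor in the
# DeepSeek-V2 config, and 65504 is the largest finite float16 value.
routed_scaling_factor = 16.0
expert_out = torch.full((4,), 5000.0, dtype=torch.float16)

# Applying the non-FP16 path naively in fp16: 5000 * 16 = 80000 -> inf.
print(expert_out * routed_scaling_factor)  # tensor([inf, inf, inf, inf])

# FP16 path: keep the routed output unscaled and divide the shared-expert
# contribution instead, so the combined tensor is carried at
# 1/routed_scaling_factor scale and stays finite.
shared_out = torch.full((4,), 5000.0, dtype=torch.float16)
combined = expert_out + shared_out * (1.0 / routed_scaling_factor)
print(combined)  # finite, about 5312 per element
```

The FP16 result is therefore the non-FP16 result divided by routed_scaling_factor; the decoder-layer hunk below compensates for that constant.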
@@ -557,12 +565,14 @@ def forward(
         )
 
         # Fully Connected
-        if isinstance(self.mlp, DeepseekV2MoE):
+        if isinstance(self.mlp, DeepseekV2MoE) and \
+                hidden_states.dtype == torch.float16:
             hidden_states *= 1. / self.mlp.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        if isinstance(self.mlp, DeepseekV2MLP):
+        if isinstance(self.mlp, DeepseekV2MLP) and \
+                hidden_states.dtype == torch.float16:
             hidden_states *= 1. / self.routed_scaling_factor
             residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
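Dividing both hidden_states and residual by the same constant in the FP16 branch is safe for the next layer because RMSNorm is scale-invariant up to its epsilon: a uniformly down-scaled residual stream normalizes to the same activations while its raw magnitudes stay inside the fp16 range. A quick sanity check of that identity with a plain RMSNorm reimplementation (illustrative only, not vLLM's RMSNorm kernel):

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm without a learned weight, for illustration only.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

routed_scaling_factor = 16.0
x = torch.randn(4, 8) * 1000.0  # large activations, as in deep layers

# Scaling the whole stream by 1/routed_scaling_factor leaves the
# normalized output (the next layer's layernorm input) unchanged.
scaled = x * (1.0 / routed_scaling_factor)
assert torch.allclose(rms_norm(x), rms_norm(scaled))
```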