Skip to content

Commit b06c154

Browse files
committed
DeepSeek V2/V3 fix for same file
1 parent e2dc610 commit b06c154

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

vllm/model_executor/models/deepseek_v2.py

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -150,11 +150,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
150150
shared_output = self.shared_experts(hidden_states)
151151
# router_logits: (num_tokens, n_experts)
152152
router_logits, _ = self.gate(hidden_states)
153-
final_hidden_states = self.experts(hidden_states=hidden_states,
154-
router_logits=router_logits)
153+
if hidden_states.dtype != torch.float16:
154+
final_hidden_states = self.experts(
155+
hidden_states=hidden_states,
156+
router_logits=router_logits) * self.routed_scaling_factor
157+
else:
158+
final_hidden_states = self.experts(hidden_states=hidden_states,
159+
router_logits=router_logits)
155160
if shared_output is not None:
156-
final_hidden_states = final_hidden_states + shared_output \
157-
* (1. / self.routed_scaling_factor)
161+
if hidden_states.dtype != torch.float16:
162+
final_hidden_states = final_hidden_states + shared_output
163+
else:
164+
final_hidden_states = final_hidden_states + shared_output \
165+
* (1. / self.routed_scaling_factor)
158166
if self.tp_size > 1:
159167
final_hidden_states = tensor_model_parallel_all_reduce(
160168
final_hidden_states)
@@ -557,12 +565,14 @@ def forward(
557565
)
558566

559567
# Fully Connected
560-
if isinstance(self.mlp, DeepseekV2MoE):
568+
if isinstance(self.mlp, DeepseekV2MoE) and \
569+
hidden_states.dtype == torch.float16:
561570
hidden_states *= 1. / self.mlp.routed_scaling_factor
562571
hidden_states, residual = self.post_attention_layernorm(
563572
hidden_states, residual)
564573
hidden_states = self.mlp(hidden_states)
565-
if isinstance(self.mlp, DeepseekV2MLP):
574+
if isinstance(self.mlp, DeepseekV2MLP) and \
575+
hidden_states.dtype == torch.float16:
566576
hidden_states *= 1. / self.routed_scaling_factor
567577
residual *= 1. / self.routed_scaling_factor
568578
return hidden_states, residual

0 commit comments

Comments
 (0)