Skip to content

Commit 95f5008

Browse files
Porting DeeSeek v2/r1 PRs (vllm-project#1756)
## Essential Elements of an Effective PR Description Checklist - [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". - [ ] The test plan, such as providing test command. - [ ] The test results, such as pasting the results comparison before and after, or e2e results ## Porting List 1. HabanaAI#1402 2. HabanaAI#1504 3. HabanaAI#1404 <!--- pyml disable-next-line no-emphasis-as-heading -->
1 parent fd41376 commit 95f5008

File tree

4 files changed

+74
-18
lines changed

4 files changed

+74
-18
lines changed

vllm/attention/backends/mla/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,9 +1171,9 @@ def get_scales(layer: LinearBase) -> torch.Tensor:
11711171
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
11721172

11731173
# Convert from (L, N, V) to (N, L, V)
1174-
self.W_UV = W_UV.transpose(0, 1)
1174+
self.W_UV = W_UV.transpose(0, 1).contiguous()
11751175
# Convert from (L, N, P) to (N, P, L)
1176-
self.W_UK_T = W_UK.permute(1, 2, 0)
1176+
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
11771177

11781178
def _compute_prefill_context(
11791179
self,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -944,20 +944,32 @@ def grouped_topk(
944944
# scores for expert selection but original scores for routing weights
945945
original_scores = scores
946946
scores = scores + e_score_correction_bias.unsqueeze(0)
947-
group_scores = (scores.view(num_token, num_expert_group,
948-
-1).topk(2, dim=-1)[0].sum(dim=-1))
947+
948+
scores_tmp = scores.clone().reshape(num_token, num_expert_group, -1)
949+
top1_val, top1_idx = torch.max(scores_tmp, dim=-1)
950+
scores_tmp.scatter_(-1, top1_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
951+
group_scores, top2_idx = torch.max(scores_tmp, dim=-1)
952+
group_scores.add_(top1_val)
949953
else:
950954
group_scores = scores.view(num_token, num_expert_group,
951955
-1).max(dim=-1).values # [n, n_group]
952-
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
953-
sorted=False)[1] # [n, top_k_group]
954-
group_mask = torch.zeros_like(group_scores) # [n, n_group]
955-
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
956-
score_mask = group_mask.unsqueeze(-1).expand(
957-
num_token, num_expert_group,
958-
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
959-
tmp_scores = scores.masked_fill(~score_mask.bool(),
960-
float("-inf")) # [n, e]
956+
957+
if num_token > 1024:
958+
group_mask = torch.zeros_like(group_scores)
959+
for i in range(topk_group):
960+
_, group_idx = torch.max(group_scores, dim=-1)
961+
group_mask.scatter_(1, group_idx.unsqueeze(-1), 1)
962+
if i < topk_group - 1:
963+
group_scores.scatter_(1, group_idx.unsqueeze(-1), torch.finfo(scores.dtype).min)
964+
else:
965+
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
966+
sorted=False)[1] # [n, top_k_group]
967+
group_mask = torch.zeros_like(group_scores) # [n, n_group]
968+
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
969+
970+
tmp_scores = scores.reshape(num_token, num_expert_group, -1) + \
971+
((1 - group_mask) * torch.finfo(scores.dtype).min).unsqueeze(-1)
972+
tmp_scores = tmp_scores.reshape(num_token, -1)
961973

962974
if e_score_correction_bias is not None:
963975
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
@@ -972,7 +984,7 @@ def grouped_topk(
972984
if renormalize:
973985
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
974986

975-
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
987+
return topk_weights.to(torch.bfloat16), topk_ids.to(torch.int32)
976988

977989

978990
def get_config_dtype_str(

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
921921
cache = torch.cat((cos, sin), dim=-1)
922922
return cache
923923

924-
def forward(
924+
def forward_native(
925925
self,
926926
positions: torch.Tensor,
927927
query: torch.Tensor,
@@ -963,6 +963,48 @@ def forward(
963963
key = key_rot
964964
return query, key
965965

966+
def forward_hpu(
967+
self,
968+
positions: torch.Tensor,
969+
query: torch.Tensor,
970+
key: torch.Tensor,
971+
offsets: Optional[torch.Tensor] = None,
972+
) -> tuple[torch.Tensor, torch.Tensor]:
973+
from habana_frameworks.torch.hpex.kernels import (
974+
RotaryPosEmbeddingMode, apply_rotary_pos_emb)
975+
"""PyTorch-native implementation equivalent to forward()."""
976+
query_rot = query[..., :self.rotary_dim]
977+
key_rot = key[..., :self.rotary_dim]
978+
if self.rotary_dim < self.head_size:
979+
query_pass = query[..., self.rotary_dim:]
980+
key_pass = key[..., self.rotary_dim:]
981+
982+
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
983+
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
984+
if offsets is not None else positions]
985+
cos, sin = cos_sin.chunk(2, dim=-1)
986+
rope_mode: RotaryPosEmbeddingMode
987+
if self.is_neox_style:
988+
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
989+
# NOTE(woosuk): Here we assume that the positions tensor has the
990+
# shape [batch_size, seq_len].
991+
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
992+
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
993+
else:
994+
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
995+
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
996+
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
997+
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
998+
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
999+
1000+
if self.rotary_dim < self.head_size:
1001+
query = torch.cat((query_rot, query_pass), dim=-1)
1002+
key = torch.cat((key_rot, key_pass), dim=-1)
1003+
else:
1004+
query = query_rot
1005+
key = key_rot
1006+
return query, key
1007+
9661008

9671009
class Llama3RotaryEmbedding(RotaryEmbedding):
9681010

vllm/model_executor/models/deepseek_v2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
156156
input_shape = hidden_states.shape
157157
hidden_dim = input_shape[-1]
158158
hidden_states = hidden_states.view(-1, hidden_dim)
159-
if self.n_shared_experts is not None:
160-
shared_output = self.shared_experts(hidden_states)
161159
# router_logits: (num_tokens, n_experts)
162160
router_logits, _ = self.gate(hidden_states)
163161

@@ -170,9 +168,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
170168
# See DeepseekV2DecoderLayer for more details.
171169
final_hidden_states = self.experts(hidden_states=hidden_states,
172170
router_logits=router_logits)
171+
172+
if self.n_shared_experts is not None:
173+
shared_output = self.shared_experts(hidden_states)
174+
173175
if shared_output is not None:
174176
if hidden_states.dtype != torch.float16:
175-
final_hidden_states = final_hidden_states + shared_output
177+
final_hidden_states.add_(shared_output)
176178
else:
177179
# Fix FP16 overflow
178180
# See DeepseekV2DecoderLayer for more details.

0 commit comments

Comments
 (0)