Commit f5cbcc1

Author: niushengxiao
Commit message: opt: refactor some code for acc
1 parent: 8bd8214

File tree

3 files changed: +41 additions, -22 deletions

  lightllm/models/deepseek2/infer_struct.py
  lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
  lightllm/models/deepseek2/model.py
lightllm/models/deepseek2/infer_struct.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+
+
+class Deepseek2InferStateInfo(LlamaInferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.kv_starts = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+        return
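For context: kv_starts extends b_start_loc (the per-request start offsets of the batch's flattened KV sequence) with one trailing boundary, giving a batch_size + 1 prefix-sum array in which request i's KV span is [kv_starts[i], kv_starts[i+1]). A minimal standalone sketch, assuming b_start_loc is the cumulative sum of b_seq_len as in lightllm's decode batches:

    import torch

    # Hypothetical batch of 3 requests with KV lengths 3, 5, 2 (illustration only).
    b_seq_len = torch.tensor([3, 5, 2])
    b_start_loc = torch.tensor([0, 3, 8])  # assumed cumsum-style start offsets

    kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)
    print(kv_starts)  # tensor([ 0,  3,  8, 10]); request i spans [kv_starts[i], kv_starts[i+1])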

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 23 additions & 22 deletions
@@ -19,7 +19,7 @@
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd
-from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
@@ -83,7 +83,7 @@ def _get_qkv(
         self,
         input: torch.Tensor,
         cache_kv,
-        infer_state: LlamaInferStateInfo,
+        infer_state: Deepseek2InferStateInfo,
         layer_weight: Deepseek2TransformerLayerWeight,
     ) -> torch.Tensor:
         input = input.view(-1, self.embed_dim_)
@@ -133,7 +133,7 @@ def _get_qkv(
         return (q_nope, q_rope), cache_kv

     def _get_o(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         if not self.disable_vo_absorb:
             input = input.view(-1, self.tp_q_head_num_ * self.kv_lora_rank)
@@ -145,7 +145,7 @@ def _get_o(
         return o_tensor

     def _CC_method(
-        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, q, compressed_kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ):
         num_local_heads = self.num_heads
         num_local_kv_heads = self.num_kv_heads
@@ -176,7 +176,7 @@ def _CC_method(
         return self._context_attention_kernel_with_v(q, [k_nope, k_pe], v, infer_state, layer_weight)

     def _ACC_method(
-        self, q, compressed_kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, q, compressed_kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ):
         q_nope, q_rope = q
         num_local_heads = self.num_heads
@@ -185,22 +185,21 @@ def _ACC_method(
         num_local_heads //= self.world_size_
         num_local_kv_heads //= self.world_size_
         # ACC
-        q_nope = layer_weight.k_b_proj_.weight.bmm(
+        q_nope = layer_weight.k_b_proj_.bmm(
             q_nope.transpose(0, 1),
         ).transpose(0, 1)
         if self.enable_opt_decoding_mha:
             import lightllm_ppl_mla

             o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
-            kvstarts = torch.cat(
-                [infer_state.b_start_loc, infer_state.b_start_loc[-1:] + infer_state.b_seq_len[-1:]], dim=0
-            )
+            q = torch.cat([q_nope, q_rope], dim=-1)
             lightllm_ppl_mla.decode_mla(
                 o_tensor,
                 q,
-                compressed_kv[: infer_state.mem_end, :, :],
-                infer_state.b_start_loc,
-                kvstarts,
+                compressed_kv,
+                infer_state.req_manager.req_to_token_indexs,
+                infer_state.kv_starts,
+                infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
                 q_nope.shape[-1],
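Two things change in this hunk. First, the per-call kvstarts concatenation is dropped in favor of infer_state.kv_starts, which init_some_extra_state now builds once per batch instead of once per layer per decode step. Second, decode_mla switches from a contiguous compressed_kv[: infer_state.mem_end] slice addressed by b_start_loc to the full cache addressed through the request-to-token index table (the k_b_proj_.bmm change routes the absorb matmul through the weight wrapper rather than its raw .weight tensor, presumably so the wrapper can pick the implementation). A rough illustration of the indexing the kernel is handed, as a hypothetical helper, assuming req_to_token_indexs[req_idx, pos] maps a request's logical KV position to a row of the cache as in lightllm's req_manager:

    import torch

    def gather_kv_rows(i, compressed_kv, infer_state):
        # KV length of request i, recovered from the precomputed boundaries.
        seq_len = int(infer_state.kv_starts[i + 1] - infer_state.kv_starts[i])
        # Rows of the (unsliced) cache that belong to request i.
        rows = infer_state.req_manager.req_to_token_indexs[infer_state.b_req_idx[i], :seq_len]
        return compressed_kv[rows.long()]  # (seq_len, kv_heads, head_dim)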
@@ -214,20 +213,20 @@ def _ACC_method(
         return vo

     def _context_attention_kernel(
-        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ) -> torch.Tensor:
         if self.mla_type == "MIX":
             return self._context_attention_kernel_with_CC(q, kv, infer_state, layer_weight, out)
         else:
             return self._context_attention_kernel_origin(q, kv, infer_state, layer_weight, out)

     def _context_attention_kernel_with_CC(
-        self, q, kv, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+        self, q, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ) -> torch.Tensor:
         return self._CC_method(q, kv, infer_state, layer_weight)

     def _context_attention_kernel_with_v(
-        self, q: Tuple[torch.Tensor, torch.Tensor], k, v, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q: Tuple[torch.Tensor, torch.Tensor], k, v, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
         k_nope, k_rope = k
@@ -267,7 +266,7 @@ def _context_attention_kernel_with_v(
         return o_tensor

     def _context_attention_kernel_origin(
-        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q: Tuple[torch.Tensor, torch.Tensor], kv, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
         q_nope, q_rope = q
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
@@ -304,20 +303,22 @@ def _context_attention_kernel_origin(
             q_rope = None
         return o_tensor

-    def _token_gqa_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+    def _token_gqa_decode_attention_flashdecoding(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
+    ):
         if self.mla_type == "MIX":
             return self._token_gqa_decode_attention_flashdecoding_with_ACC(q, infer_state, layer_weight, out)
         else:
             return self._token_gqa_decode_attention_flashdecoding_origin(q, infer_state, layer_weight, out)

     def _token_gqa_decode_attention_flashdecoding_with_ACC(
-        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ):
-        # compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_][: infer_state.mem_end, :, :]
-        return self._ACC_method(q, None, infer_state, layer_weight)
+        compressed_kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        return self._ACC_method(q, compressed_kv, infer_state, layer_weight)

     def _token_gqa_decode_attention_flashdecoding_origin(
-        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight, out=None
     ):
         q_nope, q_rope = q
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank]
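An aside on the [:, :, : self.kv_lora_rank] slice in the origin path: with MLA, each cached token row holds the compressed latent plus the shared rotary key, so the two halves are recovered by slicing the last dimension. A hedged sketch of that assumed layout, with DeepSeek-V2's usual dimensions used purely for illustration:

    import torch

    kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed model dims (DeepSeek-V2 defaults)
    # One row per cached token: latent and rope parts concatenated on the last axis.
    kv_buffer = torch.empty(1024, 1, kv_lora_rank + qk_rope_head_dim)

    kv = kv_buffer[:, :, :kv_lora_rank]      # compressed latent ("nope" part)
    k_rope = kv_buffer[:, :, kv_lora_rank:]  # shared rotary part appended per token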
@@ -347,7 +348,7 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         return

     def _moe_ffn(
-        self, input, infer_state: LlamaInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
+        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
         hidden_states = input.view(-1, self.embed_dim_)
         num_tokens, hidden_dim = hidden_states.shape

lightllm/models/deepseek2/model.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 import torch
 from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

 from lightllm.models.llama.model import LlamaTpPartModel
@@ -20,6 +21,9 @@ class Deepseek2TpPartModel(LlamaTpPartModel):
     # infer class
     transformer_layer_infer_class = Deepseek2TransformerLayerInfer

+    # infer state class
+    infer_state_class = Deepseek2InferStateInfo
+
     def __init__(self, kvargs):
         self.disable_qk_absorb = os.getenv("DISABLE_QK_ABSORB", "False").upper() in ["ON", "TRUE", "1"]
         self.disable_vo_absorb = os.getenv("DISABLE_VO_ABSORB", "False").upper() in ["ON", "TRUE", "1"]
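Overriding infer_state_class is the whole hook: the base model constructs its per-step state from this attribute, so DeepSeek-V2 forwards now carry kv_starts with no further changes. A sketch of the assumed call pattern, based on lightllm's TpPartBaseModel convention and not shown in this commit:

    # Assumed pattern inside the base model's forward preparation:
    infer_state = self.infer_state_class()              # -> Deepseek2InferStateInfo for this model
    # ... batch fields (b_start_loc, b_seq_len, b_req_idx, ...) are filled in ...
    infer_state.init_some_extra_state(self, input_ids)  # builds kv_starts once per forward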
