Commit 269631d

feat: add _context_attention_kernel_with_CC in deepseek2 (#693)
1 parent 2c64be6 commit 269631d

File tree: 3 files changed (+198 -246 lines)


lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 70 additions & 1 deletion
@@ -10,6 +10,8 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
+from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv

 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -54,6 +56,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         if mscale_all_dim:
             mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
             self.softmax_scale = self.softmax_scale * mscale * mscale
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         if self.enable_dp:
@@ -65,7 +68,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         return

     def _bind_attention(self):
-        self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self)
+        if self.enable_cc_method:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            )
+        else:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+            )
         self._token_attention_kernel = partial(
             Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
         )
@@ -123,6 +133,65 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor

+    def _context_attention_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            compressed_kv = self.alloc_tensor(
+                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
+            )
+            k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+            sample_kv(
+                kv,
+                compressed_kv,
+                k_rope,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.req_manager.req_to_token_indexs,
+            )
+        else:
+            compressed_kv, k_rope = torch.split( # (b*s, 1, kv_lora + qk_r)
+                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+            )
+
+        # CC
+        k_nope = self.alloc_tensor(
+            [compressed_kv.shape[0], q.shape[1], self.qk_nope_head_dim],
+            dtype=compressed_kv.dtype,
+        )
+        v = self.alloc_tensor(
+            k_nope.shape,
+            dtype=compressed_kv.dtype,
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank)
+        wv = layer_weight.v_b_proj_.weight.transpose(1, 2).view(-1, layer_weight.kv_lora_rank)
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=k_nope.reshape(compressed_kv.shape[0], -1))
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=v.reshape(compressed_kv.shape[0], -1))
+
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
     def _context_attention_kernel_origin(
         self,
         q: torch.Tensor,
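
Note on the new code path (reviewer sketch, not part of the commit): the "CC" variant first expands the compressed KV latent (width kv_lora_rank) back into full per-head k_nope and v through the k_b_proj_ / v_b_proj_ up-projection weights, then runs context_attention_fwd_with_v on (q_nope, q_rope, k_nope, k_rope, v). The plain-PyTorch sketch below reproduces only that expansion step; the sizes (tokens, heads, kv_lora_rank, qk_nope_head_dim) and the weight names w_k / w_v are illustrative assumptions, not values or identifiers taken from the repository.

    import torch

    # Illustrative sizes only; lightllm reads the real ones from the model config.
    tokens, heads = 16, 4
    kv_lora_rank = 512
    qk_nope_head_dim = 128

    # Compressed latent per token (the "compressed_kv" of the diff) and the
    # per-head up-projection weights (analogous to k_b_proj_ / v_b_proj_).
    compressed_kv = torch.randn(tokens, kv_lora_rank)
    w_k = torch.randn(heads * qk_nope_head_dim, kv_lora_rank)
    w_v = torch.randn(heads * qk_nope_head_dim, kv_lora_rank)

    # The "CC" expansion: one matmul per projection, then a per-head reshape.
    k_nope = (compressed_kv @ w_k.T).view(tokens, heads, qk_nope_head_dim)
    v = (compressed_kv @ w_v.T).view(tokens, heads, qk_nope_head_dim)

At runtime the path is opt-in: per the __init__ change above, setting the environment variable ENABLE_CC_METHOD to 1, ON, or TRUE binds _context_attention_kernel to _context_attention_kernel_with_CC; otherwise the original _context_attention_kernel_origin is used.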
