
Commit 9bbb40b

fix prefix cache of ds with fa3 (#861)
1 parent 3a96e3c commit 9bbb40b

File tree

3 files changed (+114, -19 lines)

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 7 additions & 2 deletions
@@ -28,8 +28,13 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                torch.cumsum(self.b_seq_len - self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
            )
            self.cu_seqlens_k = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
-           self.page_table = torch.empty((self.batch_size, self.max_seq_len), dtype=torch.int32).to(input_ids.device)
-           self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+           self.has_prefix_kv = self.b_ready_cache_len_numpy.any()
+           if self.has_prefix_kv:
+               self.cu_seqlens_prefix_k = torch.nn.functional.pad(
+                   torch.cumsum(self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
+               )
+               self.prefix_k_max_len = self.b_ready_cache_len_numpy.max()
+               self.prefix_total_token_num = self.b_ready_cache_len_numpy.sum()
        else:
            # Meta information of flashattention for decoding
            self.cu_seqlens_q = torch.arange(0, self.batch_size + 1, dtype=torch.int32, device=input_ids.device)
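
Note: for orientation, a minimal standalone sketch of the prefix metadata the new prefill branch derives from b_ready_cache_len (the batch values here are hypothetical, not from the commit). cu_seqlens_prefix_k is the cumulative-offset vector the FlashAttention varlen API expects for the cached prefix keys; prefix_k_max_len and prefix_total_token_num are scalar reductions taken on the already-synced numpy copy.

import torch

# Hypothetical prefill batch: per-request count of already-cached prefix tokens.
b_ready_cache_len = torch.tensor([0, 5, 3], dtype=torch.int32)
b_ready_cache_len_numpy = b_ready_cache_len.numpy()

has_prefix_kv = bool(b_ready_cache_len_numpy.any())            # True: some request has a cached prefix
cu_seqlens_prefix_k = torch.nn.functional.pad(
    torch.cumsum(b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
)                                                              # tensor([0, 0, 5, 8]): varlen key offsets
prefix_k_max_len = b_ready_cache_len_numpy.max()               # 5
prefix_total_token_num = b_ready_cache_len_numpy.sum()         # 8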

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 106 additions & 17 deletions
@@ -35,6 +35,7 @@

try:
    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+   from sgl_kernel import merge_state_v2
except:
    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")

@@ -248,31 +249,38 @@ def _tpsp_get_o(
        return o_tensor

    def _decompress_kv(
-       self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+       self,
+       kv,
+       infer_state: Deepseek2InferStateInfo,
+       layer_weight: Deepseek2TransformerLayerWeight,
+       is_fp8,
+       total_token_num,
+       b_seq_len,
+       max_seq_len,
+       b_kv_start_loc,
+       skip_sample=False,
    ):
-       if infer_state.use_dynamic_prompt_cache:
+       if infer_state.use_dynamic_prompt_cache and not skip_sample:
            if is_fp8:
                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
-               k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+               k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype)
            else:
                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
                kv_scale = None
                k_scale = None

-           compressed_kv = self.alloc_tensor(
-               [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
-           )
-           k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+           compressed_kv = self.alloc_tensor([total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype)
+           k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
            sample_kv(
                kv,
                compressed_kv,
                k_rope,
                infer_state.b_req_idx,
-               infer_state.max_value_in_b_seq_len,
-               infer_state.b_seq_len,
+               max_seq_len,
+               b_seq_len,
                infer_state.req_manager.req_to_token_indexs,
-               infer_state.b_kv_start_loc,
+               b_kv_start_loc,
                kv_scale,
                k_scale,
            )
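
Note: sample_kv itself is a lightllm Triton kernel; the sketch below is only a rough, loop-based plain-PyTorch picture of the gather it performs, under the assumption that req_to_token_indexs[req, pos] stores the KV-buffer row of token pos of request req and that each row is laid out as [kv_lora_rank | qk_rope_head_dim] (fp8 scales ignored). The point of the new explicit total_token_num / b_seq_len / max_seq_len / b_kv_start_loc parameters is that the same routine can now be pointed at just the cached prefix tokens by passing b_ready_cache_len and cu_seqlens_prefix_k instead of the full-sequence values.

import torch

def gather_compressed_kv(kv_buffer, req_to_token_indexs, b_req_idx, b_seq_len, kv_lora_rank):
    # Rough, illustrative equivalent of the sample_kv gather (no fp8 scale handling).
    parts = []
    for req, n in zip(b_req_idx.tolist(), b_seq_len.tolist()):
        rows = req_to_token_indexs[req, :n].long()   # KV-buffer slots of the first n tokens of this request
        parts.append(kv_buffer[rows])                # [n, 1, kv_lora_rank + qk_rope_head_dim]
    packed = torch.cat(parts, dim=0)                 # [sum(b_seq_len), 1, kv_lora_rank + qk_rope_head_dim]
    compressed_kv = packed[:, :, :kv_lora_rank]      # latent (MLA) part
    k_rope = packed[:, :, kv_lora_rank:]             # decoupled rope part
    return compressed_kv, k_rope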
@@ -294,6 +302,8 @@ def _decompress_kv(
        k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        return k_nope, k_rope, v

+   # Adapted from:
+   # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962
    def _context_attention_flashattention_kernel_with_CC(
        self,
        q: torch.Tensor,
@@ -302,9 +312,19 @@ def _context_attention_flashattention_kernel_with_CC(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-       k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+       k_nope, k_rope, v = self._decompress_kv(
+           kv,
+           infer_state,
+           layer_weight,
+           False,
+           infer_state.total_token_num,
+           infer_state.b_seq_len,
+           infer_state.max_value_in_b_seq_len,
+           infer_state.b_kv_start_loc,
+           skip_sample=True,
+       )
        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
-       o_tensor = flash_attn_varlen_func(
+       o_tensor, lse, *rest = flash_attn_varlen_func(
            q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
            k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
            v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
@@ -314,8 +334,41 @@ def _context_attention_flashattention_kernel_with_CC(
            max_seqlen_k=infer_state.max_seq_len,
            softmax_scale=self.softmax_scale,
            causal=True,
-           return_softmax_lse=False,
+           return_softmax_lse=True,
        )
+       if infer_state.has_prefix_kv:
+           k_nope, k_rope, v = self._decompress_kv(
+               kv,
+               infer_state,
+               layer_weight,
+               False,
+               infer_state.prefix_total_token_num,
+               infer_state.b_ready_cache_len,
+               infer_state.prefix_k_max_len,
+               infer_state.cu_seqlens_prefix_k,
+           )
+           k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+           prefix_output, prefix_lse, *rest = flash_attn_varlen_func(
+               q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+               k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+               v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+               cu_seqlens_q=infer_state.cu_seqlens_q,
+               cu_seqlens_k=infer_state.cu_seqlens_prefix_k,
+               max_seqlen_q=infer_state.q_max_seq_len,
+               max_seqlen_k=infer_state.prefix_k_max_len,
+               softmax_scale=self.softmax_scale,
+               causal=False,
+               return_softmax_lse=True,
+           )
+           lse = torch.transpose(lse, 0, 1).contiguous()
+           prefix_lse = torch.transpose(prefix_lse, 0, 1).contiguous()
+           tmp_output = (
+               self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype)
+               if out is None
+               else out
+           )
+           tmp_lse = torch.empty_like(lse)
+           merge_state_v2(prefix_output, prefix_lse, o_tensor, lse, tmp_output, tmp_lse)
        return o_tensor

    def _context_attention_flashinfer_kernel_with_CC(
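
Note: with prefix caching enabled, prefill attention is now computed in two FA3 passes: a causal pass of the new tokens against themselves (over the freshly built kv passed in, hence skip_sample=True above) and a non-causal pass of the same queries against the re-gathered prefix KV. The two partial results are then combined with sgl_kernel's merge_state_v2, whose last two arguments are read here as preallocated output buffers. Below is a plain-PyTorch sketch of the underlying log-sum-exp merge rule, not the kernel's actual implementation, assuming outputs shaped [tokens, heads, head_dim] and LSEs shaped [tokens, heads].

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # Numerically stable merge of two partial softmax-attention results that were
    # computed over disjoint key sets for the same queries.
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)                 # unnormalized softmax mass of part A
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    merged = o_a * (w_a / denom).unsqueeze(-1) + o_b * (w_b / denom).unsqueeze(-1)
    merged_lse = lse_max + torch.log(denom)
    return merged, merged_lse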
@@ -326,7 +379,16 @@ def _context_attention_flashinfer_kernel_with_CC(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-       k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+       k_nope, k_rope, v = self._decompress_kv(
+           kv,
+           infer_state,
+           layer_weight,
+           False,
+           infer_state.total_token_num,
+           infer_state.b_seq_len,
+           infer_state.max_value_in_b_seq_len,
+           infer_state.b_kv_start_loc,
+       )
        o_tensor = (
            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
        )
@@ -342,7 +404,16 @@ def _context_attention_flashinfer_kernel_with_CC_fp8(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-       k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+       k_nope, k_rope, v = self._decompress_kv(
+           kv,
+           infer_state,
+           layer_weight,
+           True,
+           infer_state.total_token_num,
+           infer_state.b_seq_len,
+           infer_state.max_value_in_b_seq_len,
+           infer_state.b_kv_start_loc,
+       )
        o_tensor = (
            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
        )
@@ -358,7 +429,16 @@ def _context_attention_kernel_with_CC(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-       k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+       k_nope, k_rope, v = self._decompress_kv(
+           kv,
+           infer_state,
+           layer_weight,
+           False,
+           infer_state.total_token_num,
+           infer_state.b_seq_len,
+           infer_state.max_value_in_b_seq_len,
+           infer_state.b_kv_start_loc,
+       )
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
        context_attention_fwd_with_v(
@@ -385,7 +465,16 @@ def _context_attention_kernel_with_CC_fp8(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-       k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+       k_nope, k_rope, v = self._decompress_kv(
+           kv,
+           infer_state,
+           layer_weight,
+           True,
+           infer_state.total_token_num,
+           infer_state.b_seq_len,
+           infer_state.max_value_in_b_seq_len,
+           infer_state.b_kv_start_loc,
+       )
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
        context_attention_fwd_with_v(

lightllm/models/llama/infer_struct.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
            self.max_seq_len = b_seq_len_numpy.max()
            b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
+           self.b_ready_cache_len_numpy = b_ready_cache_len_numpy
            self.q_max_seq_len = (b_seq_len_numpy - b_ready_cache_len_numpy).max()
            position_ids = torch.from_numpy(
                np.concatenate(
