 
 try:
     from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+    from sgl_kernel import merge_state_v2
 except:
     logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")
 
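The prefix-cache merge added below needs both flash_attn_varlen_func and the newly imported merge_state_v2 to be present in the installed sgl_kernel. A minimal sketch of how a caller could probe for that before enabling the FA3 prefix path; the flag name is illustrative and not part of this patch:

# Hedged sketch: probe the optional kernels once at import time.
# FA3_PREFIX_MERGE_AVAILABLE is a hypothetical flag, not existing lightllm code.
try:
    from sgl_kernel.flash_attn import flash_attn_varlen_func  # noqa: F401
    from sgl_kernel import merge_state_v2  # noqa: F401

    FA3_PREFIX_MERGE_AVAILABLE = True
except ImportError:
    FA3_PREFIX_MERGE_AVAILABLE = False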
@@ -248,31 +249,38 @@ def _tpsp_get_o(
         return o_tensor
 
     def _decompress_kv(
-        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+        self,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        is_fp8,
+        total_token_num,
+        b_seq_len,
+        max_seq_len,
+        b_kv_start_loc,
+        skip_sample=False,
     ):
-        if infer_state.use_dynamic_prompt_cache:
+        if infer_state.use_dynamic_prompt_cache and not skip_sample:
             if is_fp8:
                 kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
                 kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
-                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+                k_scale = self.alloc_tensor([total_token_num, 1], dtype=kv_scale.dtype)
             else:
                 kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
                 kv_scale = None
                 k_scale = None
 
-            compressed_kv = self.alloc_tensor(
-                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
-            )
-            k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+            compressed_kv = self.alloc_tensor([total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype)
+            k_rope = self.alloc_tensor([total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
             sample_kv(
                 kv,
                 compressed_kv,
                 k_rope,
                 infer_state.b_req_idx,
-                infer_state.max_value_in_b_seq_len,
-                infer_state.b_seq_len,
+                max_seq_len,
+                b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
-                infer_state.b_kv_start_loc,
+                b_kv_start_loc,
                 kv_scale,
                 k_scale,
             )
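For readers unfamiliar with sample_kv: it gathers each request's cached tokens out of the paged KV buffer into the contiguous compressed_kv / k_rope tensors, using the request-to-token index table and the per-request output offsets. A minimal pure-PyTorch sketch of that gather (bf16 path only; the fp8 dequantization via kv_scale/k_scale is omitted, and the kv layout [slots, 1, kv_lora_rank + qk_rope_head_dim] is an assumption read off the surrounding code):

import torch

def sample_kv_reference(kv, compressed_kv, k_rope, b_req_idx, b_seq_len,
                        req_to_token_indexs, b_kv_start_loc, kv_lora_rank):
    # Sketch of the gather performed by the Triton sample_kv kernel:
    # copy each request's cached tokens from the paged kv buffer into the
    # contiguous compressed_kv / k_rope outputs, request by request.
    for i in range(b_req_idx.shape[0]):
        seq_len = int(b_seq_len[i])
        start = int(b_kv_start_loc[i])
        # token slots of request i inside the global kv buffer
        slots = req_to_token_indexs[b_req_idx[i], :seq_len]
        sampled = kv[slots]  # [seq_len, 1, kv_lora_rank + qk_rope_head_dim]
        compressed_kv[start:start + seq_len] = sampled[:, :, :kv_lora_rank]
        k_rope[start:start + seq_len] = sampled[:, :, kv_lora_rank:]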
@@ -294,6 +302,8 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v
 
+    # Adapted from:
+    # https://github.com/sgl-project/sglang/blob/c998d04b46920f06d945fbef9023884a768723fc/python/sglang/srt/models/deepseek_v2.py#L962
     def _context_attention_flashattention_kernel_with_CC(
         self,
         q: torch.Tensor,
@@ -302,9 +312,19 @@ def _context_attention_flashattention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            False,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b_kv_start_loc,
+            skip_sample=True,
+        )
         k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
-        o_tensor = flash_attn_varlen_func(
+        o_tensor, lse, *rest = flash_attn_varlen_func(
             q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
             k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
             v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
@@ -314,8 +334,41 @@ def _context_attention_flashattention_kernel_with_CC(
             max_seqlen_k=infer_state.max_seq_len,
             softmax_scale=self.softmax_scale,
             causal=True,
-            return_softmax_lse=False,
+            return_softmax_lse=True,
         )
+        if infer_state.has_prefix_kv:
+            k_nope, k_rope, v = self._decompress_kv(
+                kv,
+                infer_state,
+                layer_weight,
+                False,
+                infer_state.prefix_total_token_num,
+                infer_state.b_ready_cache_len,
+                infer_state.prefix_k_max_len,
+                infer_state.cu_seqlens_prefix_k,
+            )
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            prefix_output, prefix_lse, *rest = flash_attn_varlen_func(
+                q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+                k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+                v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+                cu_seqlens_q=infer_state.cu_seqlens_q,
+                cu_seqlens_k=infer_state.cu_seqlens_prefix_k,
+                max_seqlen_q=infer_state.q_max_seq_len,
+                max_seqlen_k=infer_state.prefix_k_max_len,
+                softmax_scale=self.softmax_scale,
+                causal=False,
+                return_softmax_lse=True,
+            )
+            lse = torch.transpose(lse, 0, 1).contiguous()
+            prefix_lse = torch.transpose(prefix_lse, 0, 1).contiguous()
+            tmp_output = (
+                self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype)
+                if out is None
+                else out
+            )
+            tmp_lse = torch.empty_like(lse)
+            merge_state_v2(prefix_output, prefix_lse, o_tensor, lse, tmp_output, tmp_lse)
+            # the merged (prefix + new-token) attention result lives in tmp_output
+            return tmp_output
         return o_tensor
 
     def _context_attention_flashinfer_kernel_with_CC(
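merge_state_v2 (from sgl_kernel) combines the two partial attention results, one computed over the newly fed tokens and one over the cached prefix, using their log-sum-exp statistics. A reference sketch of the standard merge it is understood to implement (not the fused GPU kernel itself); shapes follow the transposed lse above, i.e. o_* is [num_tokens, num_heads, head_dim] and lse_* is [num_tokens, num_heads]:

import torch

def merge_attention_states(o_a, lse_a, o_b, lse_b):
    # Standard log-sum-exp merge of two attention partials computed over
    # disjoint key sets (as used in flash-decoding / cascade attention).
    lse = torch.logaddexp(lse_a, lse_b)          # merged log-sum-exp
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of partial a
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of partial b
    return o_a * w_a + o_b * w_b, lse

Each partial output is already normalized by its own softmax denominator, so reweighting by exp(lse_* - lse) recovers the attention over the union of both key sets.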
@@ -326,7 +379,16 @@ def _context_attention_flashinfer_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            False,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b_kv_start_loc,
+        )
         o_tensor = (
             self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
         )
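The prefix branch in the flashattention kernel above reads several fields from infer_state (has_prefix_kv, prefix_total_token_num, prefix_k_max_len, cu_seqlens_prefix_k) that this diff does not define. A hedged sketch of how such metadata is typically derived, assuming b_ready_cache_len holds the number of already-cached prefix tokens per request (the helper name is illustrative, not from the patch):

import torch

def build_prefix_metadata(b_ready_cache_len: torch.Tensor):
    # b_ready_cache_len: int32 [batch], cached prefix length of each request.
    prefix_total_token_num = int(b_ready_cache_len.sum())
    prefix_k_max_len = int(b_ready_cache_len.max()) if b_ready_cache_len.numel() else 0
    # cu_seqlens layout expected by flash_attn_varlen_func: [0, l0, l0 + l1, ...]
    cu_seqlens_prefix_k = torch.nn.functional.pad(
        torch.cumsum(b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
    )
    has_prefix_kv = prefix_total_token_num > 0
    return has_prefix_kv, prefix_total_token_num, prefix_k_max_len, cu_seqlens_prefix_k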
@@ -342,7 +404,16 @@ def _context_attention_flashinfer_kernel_with_CC_fp8(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            True,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b_kv_start_loc,
+        )
         o_tensor = (
             self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
         )
@@ -358,7 +429,16 @@ def _context_attention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            False,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b_kv_start_loc,
+        )
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(
@@ -385,7 +465,16 @@ def _context_attention_kernel_with_CC_fp8(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+        k_nope, k_rope, v = self._decompress_kv(
+            kv,
+            infer_state,
+            layer_weight,
+            True,
+            infer_state.total_token_num,
+            infer_state.b_seq_len,
+            infer_state.max_value_in_b_seq_len,
+            infer_state.b_kv_start_loc,
+        )
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(