 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+except:
+    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
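The guarded import above makes FA3 an optional dependency: the module still imports when sgl_kernel is missing, and the failure only matters if an FA3 code path is actually selected later. A minimal, self-contained sketch of that same pattern, using a hypothetical HAS_FA3 flag and select_decode_kernel helper that do not exist in lightllm:

import logging

logger = logging.getLogger(__name__)

try:
    # FA3 kernels shipped by sgl_kernel; absence is tolerated at import time.
    from sgl_kernel.flash_attn import flash_attn_with_kvcache  # noqa: F401

    HAS_FA3 = True
except ImportError:
    HAS_FA3 = False
    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")


def select_decode_kernel(enable_fa3: bool) -> str:
    # Hypothetical helper: only route to FA3 when the flag is set and the import succeeded.
    if enable_fa3 and HAS_FA3:
        return "fa3"
    return "flashdecoding"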
@@ -93,7 +101,11 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        if get_env_start_args().enable_flashinfer_decode:
+        if get_env_start_args().enable_fa3:
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
+            )
+        elif get_env_start_args().enable_flashinfer_decode:
             self._token_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
             )
@@ -112,7 +124,11 @@ def _bind_attention(self):
                     Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
                 )
             else:
-                if get_env_start_args().enable_flashinfer_prefill:
+                if get_env_start_args().enable_fa3:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
+                    )
+                elif get_env_start_args().enable_flashinfer_prefill:
                     self._context_attention_kernel = partial(
                         Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
                     )
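Both _bind_attention hunks follow the same dispatch pattern: the startup flags are checked in priority order (FA3, then flashinfer, then the default path) and the chosen kernel is bound as a method via functools.partial. A stripped-down sketch of that pattern, with hypothetical kernel functions and flag names standing in for the real ones:

from functools import partial


class _Layer:
    # Hypothetical stand-ins for the real attention kernels.
    def _decode_fa3(self, q):
        return "fa3", q

    def _decode_flashinfer(self, q):
        return "flashinfer", q

    def _decode_default(self, q):
        return "triton", q

    def bind(self, enable_fa3: bool, enable_flashinfer_decode: bool):
        # Priority order mirrors the diff: fa3 > flashinfer > default.
        if enable_fa3:
            self._token_attention_kernel = partial(_Layer._decode_fa3, self)
        elif enable_flashinfer_decode:
            self._token_attention_kernel = partial(_Layer._decode_flashinfer, self)
        else:
            self._token_attention_kernel = partial(_Layer._decode_default, self)


layer = _Layer()
layer.bind(enable_fa3=True, enable_flashinfer_decode=False)
assert layer._token_attention_kernel(q=None)[0] == "fa3"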
@@ -278,6 +294,30 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v

+    def _context_attention_flashattention_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2FlashInferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        o_tensor = flash_attn_varlen_func(
+            q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k=infer_state.cu_seqlens_k,
+            max_seqlen_q=infer_state.q_max_seq_len,
+            max_seqlen_k=infer_state.max_seq_len,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _context_attention_flashinfer_kernel_with_CC(
         self,
         q: torch.Tensor,
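The new prefill path decompresses the MLA cache into full per-head K/V (repeating the single shared rope key across the tp_q_head_num_ heads) and hands the result to flash_attn_varlen_func in the usual varlen layout: flattened (total_tokens, heads, head_dim) tensors plus exclusive prefix sums over the per-request lengths. How infer_state fills cu_seqlens_q / cu_seqlens_k is not shown in this diff; the sketch below only illustrates that standard convention with made-up lengths:

import torch

# Per-request query lengths for a 3-request prefill batch (illustrative values).
b_q_seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)

# Exclusive prefix sum with a leading zero, as varlen flash attention expects.
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(b_q_seq_len, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen_q = int(b_q_seq_len.max())

print(cu_seqlens_q.tolist(), max_seqlen_q)  # [0, 3, 8, 10] 5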
@@ -450,6 +490,35 @@ def _context_attention_kernel_origin_fp8(

         return o_tensor

+    def _token_gqa_decode_attention_flashattention(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
+        k_descale, v_descale = None, None
+        o_tensor = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope,
+            v_cache=kv_nope,
+            qv=q_nope,
+            page_table=infer_state.page_table,
+            cache_seqlens=infer_state.b_seq_len,
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            window_size=(-1, -1),
+            softcap=0.0,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashinfer(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
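In the decode path the query's nope part is first absorbed through k_b_proj_, so attention can run against the compressed MLA cache directly: each per-token cache entry of width kv_lora_rank + qk_rope_head_dim is split into a rope key part (passed as k_cache, matched by q_rope) and a compressed part (passed as v_cache, matched by the absorbed q_nope via qv), each reshaped into the page_size-1 paged layout that flash_attn_with_kvcache consumes. A small shape-only sketch of that split; the 512/64 widths match DeepSeek-V2-style MLA but are purely illustrative:

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64  # illustrative MLA widths
num_cached_tokens = 10

# Joint MLA cache: one "head" holding [compressed_kv | rope_k] per token.
kv = torch.randn(num_cached_tokens, 1, kv_lora_rank + qk_rope_head_dim)

# Split and reshape into a paged layout with page_size=1 (one token per page).
k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim)
kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank)

assert k_rope.shape == (num_cached_tokens, 1, 1, qk_rope_head_dim)
assert kv_nope.shape == (num_cached_tokens, 1, 1, kv_lora_rank)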