 import numpy as np
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
+from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
-from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+from lightllm.utils.envs_utils import enable_env_vars
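+# Note: enable_env_vars(name) is assumed to report whether the env var `name` is set to a
+# truthy value ("ON"/"TRUE"/"1"), matching the os.getenv check it replaces below.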


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,7 +68,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.tp_o_head_num_ = self.tp_q_head_num_
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
-        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return

     def _bind_func(self):
@@ -96,18 +96,33 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        if self.enable_cc_method:
-            if "triton_fp8kv" in self.mode:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
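+        # Decode dispatch: ENABLE_FLASHINFER_DECODE_MLA routes MLA decoding through the
+        # flashinfer decode wrapper; otherwise the existing flash-decoding path below is kept.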
+        if enable_env_vars("ENABLE_FLASHINFER_DECODE_MLA"):
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
             )
         else:
-                self._context_attention_kernel = partial(
-                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
             )
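+        # Prefill dispatch: under the CC method, ENABLE_FLASHINFER_PREFILLED selects the
+        # flashinfer prefill wrapper for both the fp8 and default KV-cache modes; otherwise
+        # the Triton CC kernels are kept.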
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC_fp8, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                    )
+            else:
+                if enable_env_vars("ENABLE_FLASHINFER_PREFILLED"):
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
+                    )
+                else:
+                    self._context_attention_kernel = partial(
+                        Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                    )
         else:
             if "triton_fp8kv" in self.mode:
                 self._context_attention_kernel = partial(
@@ -205,6 +220,38 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v

+    def _context_attention_flashinfer_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
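+        # Decompress the latent KV cache into full per-head K/V, broadcast the shared rope
+        # part across the local query heads, and run flashinfer's prefill wrapper.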
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
+    def _context_attention_flashinfer_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
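+        # fp8 variant: same as the method above, except the final True flag asks
+        # _decompress_kv to read the KV cache as fp8 before the flashinfer prefill call.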
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
+        )
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        return o_tensor
+
     def _context_attention_kernel_with_CC(
         self,
         q: torch.Tensor,
@@ -345,6 +392,25 @@ def _context_attention_kernel_origin_fp8(

         return o_tensor

+    def _token_gqa_decode_attention_flashinfer(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
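+        # MLA decode via flashinfer: split q into its no-rope / rope parts, project q_nope
+        # through k_b_proj_ (weight absorption), then let decode_wrapper attend over the
+        # compressed KV buffer, itself split into its no-rope and rope parts.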
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)
+
+        infer_state.decode_wrapper.run(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            out=o_tensor,
+            return_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -354,7 +420,7 @@ def _token_gqa_decode_attention_flashdecoding(
         kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype)

-        if self.enable_opt_decoding_mha:
+        if enable_env_vars("ENABLE_OPT_DECODE_MHA"):
             q = torch.cat([q_nope, q_rope], dim=-1)
             q_nope, q_rope = None, None
             import lightllm_ppl_mla
@@ -368,7 +434,7 @@ def _token_gqa_decode_attention_flashdecoding(
                 infer_state.b_req_idx,
                 self.softmax_scale,
                 q.shape[-1],
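+                # dim of the compressed latent kv; q_nope was set to None above, so its shape is unavailable here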
-                q_nope.shape[-1],
+                self.kv_lora_rank,
             )
             return o_tensor
         else:
@@ -421,16 +487,13 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         return

     def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
-        quant_method = vLLMFP8w8a8QuantizationMethod()
-        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
-        destindex_copy_kv(
-            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
-            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
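+        # Quantize the latent kv to fp8 e4m3 and scatter it into the shared kv buffer with a
+        # dedicated Triton kernel; the last two bytes of each cache entry appear to hold the
+        # per-token scale, viewed in the buffer's original dtype.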
+        destindex_copy_kv_fp8(
+            buffer[:, :, : self.kv_lora_rank],
+            buffer[:, :, self.kv_lora_rank :],
             mem_index,
-            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
-            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
-            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
-            scale.to(buffer.dtype).view(torch.uint8),
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2].view(torch.float8_e4m3fn),
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(buffer.dtype),
         )
         return
