from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
    context_attention_fwd,
+    context_attention_fwd_fp8,
    context_attention_fwd_no_prompt_cache,
)
from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv

-from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
+from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import (
+    gqa_token_decode_attention_flash_decoding,
+    gqa_token_decode_attention_flash_decoding_fp8,
+)
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from functools import partial
from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod


class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,19 +72,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
        return

-    def _bind_attention(self):
-        if self.enable_cc_method:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-            )
-        else:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
-            )
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+    def _bind_func(self):
+        super()._bind_func()
+        self._bind_ffn()
+        return
+
+    def _bind_ffn(self):
        if self.is_moe:
            if self.enable_dp:
                if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -92,6 +90,36 @@ def _bind_attention(self):
        else:
            self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)

+    def _bind_attention(self):
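+        # Kernel routing: "triton_fp8kv" mode stores the compressed KV cache in FP8 with per-token
+        # scales, so it binds the *_fp8 copy/decode/context kernels; otherwise the normal kernels are used.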
+        if "triton_fp8kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
+            )
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                )
+        else:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+                )
+
    def _get_qkv(
        self,
        input: torch.Tensor,
@@ -133,9 +161,19 @@ def _get_o(
        o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
        return o_tensor

-    def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+    def _decompress_kv(
+        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+    ):
        if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if is_fp8:
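+                # Packed FP8 layout: all but the last two uint8 slots of a cache row hold the FP8 payload,
+                # and the final two bytes reinterpret as a single bfloat16 per-token scale.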
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+
            compressed_kv = self.alloc_tensor(
                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
            )
@@ -147,7 +185,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                infer_state.b_req_idx,
                infer_state.b_seq_len,
                infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
            )
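+            # FP8 path: rescale the gathered FP8 values back to the working dtype using the per-token scales.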
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
        else:
            compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
@@ -177,7 +220,33 @@ def _context_attention_kernel_with_CC(
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight)
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
+    def _context_attention_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
        context_attention_fwd_with_v(
@@ -237,6 +306,50 @@ def _context_attention_kernel_origin(

        return o_tensor

+    def _context_attention_kernel_origin_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
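+        # Absorb the key up-projection (k_b_proj_) into q_nope so attention runs directly in the
+        # compressed kv_lora_rank space instead of materializing full per-head keys.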
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+            context_attention_fwd_fp8(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+
+        return o_tensor
+
    def _token_gqa_decode_attention_flashdecoding(
        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
    ):
@@ -279,6 +392,29 @@ def _token_gqa_decode_attention_flashdecoding(
            alloc_tensor_func=self.alloc_tensor,
        )

+    def _token_gqa_decode_attention_flashdecoding_fp8(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
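+        # Split each packed cache row into its FP8 payload and the trailing bfloat16 scale before decoding.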
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+        return gqa_token_decode_attention_flash_decoding_fp8(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            kv_scale,
+            infer_state,
+            self.tp_q_head_num_,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.qk_nope_head_dim,
+            self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
    def _splitfuse_attention_kernel(
        self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
@@ -321,6 +457,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
        )
        return

+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
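+        # Quantize the compressed KV to FP8, then write the payload and the bfloat16 scale
+        # (viewed as two uint8 bytes) into the packed cache rows selected by mem_index.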
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
    def _ffn_dp(
        self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
    ) -> torch.Tensor: