     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
 from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv

 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
+from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,19 +70,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return

-    def _bind_attention(self):
-        if self.enable_cc_method:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-            )
-        else:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
-            )
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+    def _bind_func(self):
+        super()._bind_func()
+        self._bind_ffn()
+        return
+
+    def _bind_ffn(self):
         if self.is_moe:
             if self.enable_dp:
                 if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -92,6 +88,36 @@ def _bind_attention(self):
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)

+    def _bind_attention(self):
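+        # NOTE: "triton_fp8kv" mode binds the fp8 KV-cache variants of the copy, decode and
+        # prefill attention kernels; otherwise the original (non-quantized) paths are kept.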
+        if "triton_fp8kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
+            )
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                )
+        else:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+                )
+
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -133,9 +159,19 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor

-    def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+    def _decompress_kv(
+        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+    ):
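+        # When is_fp8 is True, each slot of the shared kv_buffer holds the fp8 payload in its
+        # leading bytes and a bf16 dequantization scale in its trailing 2 bytes.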
         if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if is_fp8:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+
             compressed_kv = self.alloc_tensor(
                 [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
             )
@@ -147,7 +183,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                 infer_state.b_req_idx,
                 infer_state.b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
             )
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
         else:
             compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                 kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
@@ -177,7 +218,33 @@ def _context_attention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight)
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
+    def _context_attention_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
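+        # Same as _context_attention_kernel_with_CC, but dequantizes the fp8 KV cache
+        # (is_fp8=True) before running the prefill attention kernel.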
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(
@@ -237,6 +304,50 @@ def _context_attention_kernel_origin(

         return o_tensor

+    def _context_attention_kernel_origin_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
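+        # fp8 variant of _context_attention_kernel_origin: when the dynamic prompt cache is
+        # used, the fp8 values and bf16 scales are split out of the shared kv_buffer and
+        # handed to the fp8 prefill kernel.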
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+            context_attention_fwd_fp8(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -279,6 +390,29 @@ def _token_gqa_decode_attention_flashdecoding(
             alloc_tensor_func=self.alloc_tensor,
         )

+    def _token_gqa_decode_attention_flashdecoding_fp8(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
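+        # fp8 flash decoding: read the fp8 KV values plus the trailing bf16 scale from the
+        # shared kv_buffer and pass both to the fp8 decode kernel.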
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+        return gqa_token_decode_attention_flash_decoding_fp8(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            kv_scale,
+            infer_state,
+            self.tp_q_head_num_,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.qk_nope_head_dim,
+            self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
     def _splitfuse_attention_kernel(
         self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -321,6 +455,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         )
         return

+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
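+        # Quantize the compressed KV (lora part + rope part) to fp8 and copy it into the
+        # destination slots; the quantization scale is stored, reinterpreted as two uint8
+        # bytes, in the last two bytes of each kv_buffer slot.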
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
     def _ffn_dp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor: