1010 context_attention_fwd ,
1111 context_attention_fwd_no_prompt_cache ,
1212)
13+ from lightllm .models .deepseek2 .triton_kernel .context_flashattention_nopad_fp8 import context_attention_fwd_fp8
1314from lightllm .models .deepseek2 .triton_kernel .context_flashattention_nopad_with_v import context_attention_fwd_with_v
1415from lightllm .models .deepseek2 .triton_kernel .sample_kv import sample_kv
1516
1617from lightllm .models .deepseek2 .triton_kernel .gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
18+ from lightllm .models .deepseek2 .triton_kernel .gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8
1719from lightllm .models .llama .layer_infer .transformer_layer_infer import LlamaTransformerLayerInfer
1820from lightllm .models .llama .triton_kernel .rmsnorm import rmsnorm_forward
1921from lightllm .models .llama .triton_kernel .silu_and_mul import silu_and_mul_fwd
2224from functools import partial
2325from lightllm .models .llama .yarn_rotary_utils import get_deepseek_mscale
2426import os
27+ from lightllm .common .quantization import vLLMFP8w8a8QuantizationMethod
2528
2629
2730class Deepseek2TransformerLayerInfer (LlamaTransformerLayerInfer ):
@@ -67,19 +70,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
6770 self .enable_opt_decoding_mha = os .getenv ("ENABLE_OPT_DECODE_MHA" , "False" ).upper () in ["ON" , "TRUE" , "1" ]
6871 return
6972
70- def _bind_attention (self ):
71- if self .enable_cc_method :
72- self ._context_attention_kernel = partial (
73- Deepseek2TransformerLayerInfer ._context_attention_kernel_with_CC , self
74- )
75- else :
76- self ._context_attention_kernel = partial (
77- Deepseek2TransformerLayerInfer ._context_attention_kernel_origin , self
78- )
79- self ._token_attention_kernel = partial (
80- Deepseek2TransformerLayerInfer ._token_gqa_decode_attention_flashdecoding , self
81- )
82- self ._copy_kv_to_mem_cache = partial (Deepseek2TransformerLayerInfer ._copy_kv_to_mem_cache_normal , self )
    def _bind_func(self):
        """Bind this layer's kernel implementations.

        Runs the parent class binding first (which installs the attention
        kernels via ``_bind_attention``), then binds the FFN variant chosen
        by the MoE/DP configuration in ``_bind_ffn``.
        """
        super()._bind_func()
        self._bind_ffn()
        return
78+ def _bind_ffn (self ):
8379 if self .is_moe :
8480 if self .enable_dp :
8581 if os .environ .get ("MOE_MODE" , "TP" ) == "TP" :
@@ -92,6 +88,36 @@ def _bind_attention(self):
9288 else :
9389 self ._ffn = partial (LlamaTransformerLayerInfer ._ffn , self )
9490
91+ def _bind_attention (self ):
92+ if "triton_fp8kv" in self .mode :
93+ self ._copy_kv_to_mem_cache = partial (Deepseek2TransformerLayerInfer ._copy_kv_to_mem_cache_fp8 , self )
94+ self ._token_attention_kernel = partial (
95+ Deepseek2TransformerLayerInfer ._token_gqa_decode_attention_flashdecoding_fp8 , self
96+ )
97+ else :
98+ self ._copy_kv_to_mem_cache = partial (Deepseek2TransformerLayerInfer ._copy_kv_to_mem_cache_normal , self )
99+ self ._token_attention_kernel = partial (
100+ Deepseek2TransformerLayerInfer ._token_gqa_decode_attention_flashdecoding , self
101+ )
102+ if self .enable_cc_method :
103+ if "triton_fp8kv" in self .mode :
104+ self ._context_attention_kernel = partial (
105+ Deepseek2TransformerLayerInfer ._context_attention_kernel_with_CC_fp8 , self
106+ )
107+ else :
108+ self ._context_attention_kernel = partial (
109+ Deepseek2TransformerLayerInfer ._context_attention_kernel_with_CC , self
110+ )
111+ else :
112+ if "triton_fp8kv" in self .mode :
113+ self ._context_attention_kernel = partial (
114+ Deepseek2TransformerLayerInfer ._context_attention_kernel_origin_fp8 , self
115+ )
116+ else :
117+ self ._context_attention_kernel = partial (
118+ Deepseek2TransformerLayerInfer ._context_attention_kernel_origin , self
119+ )
120+
95121 def _get_qkv (
96122 self ,
97123 input : torch .Tensor ,
@@ -133,9 +159,19 @@ def _get_o(
133159 o_tensor = layer_weight .o_weight_ .mm (input .reshape (- 1 , self .tp_q_head_num_ * self .qk_nope_head_dim ))
134160 return o_tensor
135161
136- def _decompress_kv (self , kv , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight ):
162+ def _decompress_kv (
163+ self , kv , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight , is_fp8
164+ ):
137165 if infer_state .use_dynamic_prompt_cache :
138- kv = infer_state .mem_manager .kv_buffer [self .layer_num_ ]
166+ if is_fp8 :
167+ kv = infer_state .mem_manager .kv_buffer [self .layer_num_ ][:, :, :- 2 ].view (torch .float8_e4m3fn )
168+ kv_scale = infer_state .mem_manager .kv_buffer [self .layer_num_ ][:, :, - 2 :].view (torch .bfloat16 )
169+ k_scale = self .alloc_tensor ([infer_state .total_token_num , 1 ], dtype = kv_scale .dtype )
170+ else :
171+ kv = infer_state .mem_manager .kv_buffer [self .layer_num_ ]
172+ kv_scale = None
173+ k_scale = None
174+
139175 compressed_kv = self .alloc_tensor (
140176 [infer_state .total_token_num , 1 , layer_weight .kv_lora_rank ], dtype = kv .dtype
141177 )
@@ -147,7 +183,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
147183 infer_state .b_req_idx ,
148184 infer_state .b_seq_len ,
149185 infer_state .req_manager .req_to_token_indexs ,
186+ kv_scale ,
187+ k_scale ,
150188 )
189+ if k_scale is not None :
190+ compressed_kv = compressed_kv .to (k_scale .dtype ) * k_scale .unsqueeze (- 1 )
191+ k_rope = k_rope .to (k_scale .dtype ) * k_scale .unsqueeze (- 1 )
151192 else :
152193 compressed_kv , k_rope = torch .split ( # (b*s, 1, kv_lora + qk_r)
153194 kv , [layer_weight .kv_lora_rank , layer_weight .qk_rope_head_dim ], dim = - 1
@@ -175,7 +216,33 @@ def _context_attention_kernel_with_CC(
175216 layer_weight : Deepseek2TransformerLayerWeight ,
176217 out = None ,
177218 ) -> torch .Tensor :
178- k_nope , k_rope , v = self ._decompress_kv (kv , infer_state , layer_weight )
219+ k_nope , k_rope , v = self ._decompress_kv (kv , infer_state , layer_weight , False )
220+ q_nope , q_rope = q [:, :, : - self .qk_rope_head_dim ], q [:, :, - self .qk_rope_head_dim :]
221+ o_tensor = self .alloc_tensor (q_nope .shape , dtype = q_nope .dtype ) if out is None else out
222+ context_attention_fwd_with_v (
223+ q_nope ,
224+ q_rope ,
225+ k_nope ,
226+ k_rope ,
227+ v ,
228+ o_tensor .view (- 1 , self .tp_q_head_num_ , q_nope .shape [- 1 ]),
229+ infer_state .b_start_loc ,
230+ infer_state .b_seq_len ,
231+ infer_state .b_ready_cache_len ,
232+ infer_state .max_len_in_batch ,
233+ self .softmax_scale ,
234+ )
235+ return o_tensor
236+
237+ def _context_attention_kernel_with_CC_fp8 (
238+ self ,
239+ q : torch .Tensor ,
240+ kv ,
241+ infer_state : Deepseek2InferStateInfo ,
242+ layer_weight : Deepseek2TransformerLayerWeight ,
243+ out = None ,
244+ ) -> torch .Tensor :
245+ k_nope , k_rope , v = self ._decompress_kv (kv , infer_state , layer_weight , True )
179246 q_nope , q_rope = q [:, :, : - self .qk_rope_head_dim ], q [:, :, - self .qk_rope_head_dim :]
180247 o_tensor = self .alloc_tensor (q_nope .shape , dtype = q_nope .dtype ) if out is None else out
181248 context_attention_fwd_with_v (
@@ -235,6 +302,50 @@ def _context_attention_kernel_origin(
235302
236303 return o_tensor
237304
    def _context_attention_kernel_origin_fp8(
        self,
        q: torch.Tensor,
        kv,
        infer_state: Deepseek2InferStateInfo,
        layer_weight: Deepseek2TransformerLayerWeight,
        out=None,
    ) -> torch.Tensor:
        """Prefill (context) attention, origin path, with an fp8 KV cache.

        Splits q into its nope/rope parts, absorbs the k_b projection into
        q_nope, and dispatches to one of two Triton kernels:

        - with a dynamic prompt cache: reads the fp8 KV cache plus its
          per-token scales and runs ``context_attention_fwd_fp8``;
        - otherwise: runs ``context_attention_fwd_no_prompt_cache`` on the
          freshly computed (non-quantized) ``kv`` argument.

        Returns ``o_tensor`` of the same shape/dtype as the projected q_nope
        (allocated here unless ``out`` is provided).
        """
        # Last qk_rope_head_dim channels of q are the rotary part.
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        # Absorb k_b_proj into the query (MLA trick): q_nope is projected to
        # kv_lora_rank so attention runs in the compressed latent space.
        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
        if infer_state.use_dynamic_prompt_cache:
            # KV-cache byte layout: all but the last 2 bytes per head are fp8
            # payload; the last 2 bytes are a bf16 dequantization scale
            # (see _copy_kv_to_mem_cache_fp8).
            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
            context_attention_fwd_fp8(
                q_nope,
                q_rope,
                kv[:, :, : -self.qk_rope_head_dim],
                kv[:, :, -self.qk_rope_head_dim :],
                kv_scale,
                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
                infer_state.b_req_idx,
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.b_ready_cache_len,
                infer_state.max_len_in_batch,
                infer_state.req_manager.req_to_token_indexs,
                self.softmax_scale,
            )
        else:
            # No prompt cache: the incoming `kv` argument is used directly
            # (it has not been fp8-quantized on this path).
            context_attention_fwd_no_prompt_cache(
                q_nope,
                q_rope,
                kv[:, :, : -self.qk_rope_head_dim],
                kv[:, :, -self.qk_rope_head_dim :],
                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.max_len_in_batch,
                self.softmax_scale,
            )

        return o_tensor
238349 def _token_gqa_decode_attention_flashdecoding (
239350 self , q , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight , out = None
240351 ):
@@ -277,6 +388,29 @@ def _token_gqa_decode_attention_flashdecoding(
277388 alloc_tensor_func = self .alloc_tensor ,
278389 )
279390
    def _token_gqa_decode_attention_flashdecoding_fp8(
        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
    ):
        """Decode-phase GQA flash-decoding attention against the fp8 KV cache.

        Splits q into nope/rope parts, absorbs k_b_proj into q_nope (MLA
        latent-space attention), reinterprets the cache buffer as fp8 payload
        plus bf16 scales, and delegates to
        ``gqa_token_decode_attention_flash_decoding_fp8``.

        NOTE(review): ``out`` is accepted for signature parity with the
        non-fp8 variant but is not forwarded to the kernel here.
        """
        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)

        # KV-cache byte layout: fp8 payload followed by a 2-byte bf16 scale
        # per head (written by _copy_kv_to_mem_cache_fp8).
        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
        return gqa_token_decode_attention_flash_decoding_fp8(
            q_nope,
            q_rope,
            kv[:, :, : -self.qk_rope_head_dim],
            kv[:, :, -self.qk_rope_head_dim :],
            kv_scale,
            infer_state,
            self.tp_q_head_num_,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
            self.qk_nope_head_dim,
            self.softmax_scale,
            alloc_tensor_func=self.alloc_tensor,
        )
280414 def _splitfuse_attention_kernel (
281415 self , q , infer_state : SplitFuseInferStateInfo , layer_weight , out = None
282416 ) -> torch .Tensor :
@@ -319,6 +453,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
319453 )
320454 return
321455
    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
        """Quantize a KV buffer to fp8 and scatter it into the shared cache.

        The destination cache row layout per token/head is:
        ``[0 : kv_lora_rank]``            fp8 nope/latent payload (as uint8),
        ``[kv_lora_rank : -2]``           fp8 rope payload (as uint8),
        ``[-2 :]``                        the quantization scale packed into
                                          the last 2 bytes (read back as
                                          bf16 by the fp8 attention kernels).

        NOTE(review): a fresh vLLMFP8w8a8QuantizationMethod is constructed on
        every call; if construction is non-trivial, caching one instance may
        be worthwhile — confirm it is stateless first.
        """
        quant_method = vLLMFP8w8a8QuantizationMethod()
        # Per-row quantization over the flattened buffer; `scale` is one
        # value per row.
        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
        destindex_copy_kv(
            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
            mem_index,
            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
            # Reinterpret the (16-bit) scale as raw bytes so it can be stored
            # in the uint8 cache tail.
            scale.to(buffer.dtype).view(torch.uint8),
        )
        return
322470 def _ffn_dp (
323471 self , input , infer_state : Deepseek2InferStateInfo , layer_weight : Deepseek2TransformerLayerWeight
324472 ) -> torch .Tensor :
0 commit comments