 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
 from lightllm.common.basemodel import TransformerLayerInferTpl
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
@@ -68,8 +69,12 @@ def _bind_attention(self):
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
             return
-
-        self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self)
+        elif get_env_start_args().enable_flashinfer_prefill:
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self
+            )
+        else:
+            self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self)
         if "ppl_int8kv" in self.mode:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
@@ -119,7 +124,12 @@ def _bind_attention(self):
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         else:
-            self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
+            if get_env_start_args().enable_flashinfer_decode:
+                self._token_attention_kernel = partial(
+                    LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self
+                )
+            else:
+                self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
 
         return
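Note on the binding pattern used throughout `_bind_attention`: `partial(LlamaTransformerLayerInfer.<kernel>, self)` turns an unbound method into a callable already bound to this layer instance, so the attention backend is selected once at startup (here via the `enable_flashinfer_prefill` / `enable_flashinfer_decode` start args) and the forward path never branches. A minimal, self-contained sketch of the same pattern; the class and flag below are illustrative, not lightllm's:

```python
from functools import partial


class _TinyLayer:
    """Toy layer showing how kernels are rebound via functools.partial."""

    def __init__(self, use_flashinfer: bool):
        # Pick the backend once at construction time, mirroring _bind_attention.
        if use_flashinfer:
            self._attn = partial(_TinyLayer._attn_flashinfer, self)
        else:
            self._attn = partial(_TinyLayer._attn_triton, self)

    def _attn_flashinfer(self, q):
        return f"flashinfer({q})"

    def _attn_triton(self, q):
        return f"triton({q})"

    def forward(self, q):
        # Callers never branch; they just call the bound kernel.
        return self._attn(q)


print(_TinyLayer(True).forward("q0"))   # flashinfer(q0)
print(_TinyLayer(False).forward("q0"))  # triton(q0)
```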
@@ -178,6 +188,28 @@ def _tpsp_get_qkv(
         )
         return q, cache_kv
 
+    def _context_attention_flashinfer_kernel(
+        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            kv = kv.unsqueeze(1)
+            infer_state.prefill_wrapper.run(
+                q.view(q.shape[0], -1, self.head_dim_),
+                (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+                out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+            )
+        else:
+            infer_state.prefill_wrapper.run(
+                q.view(q.shape[0], -1, self.head_dim_),
+                kv[:, : self.tp_k_head_num_, :],
+                kv[:, self.tp_k_head_num_ :, :],
+                out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+            )
+
+        return o_tensor
+
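The new prefill kernel relies on the layout of `mem_manager.kv_buffer[self.layer_num_]`: as the slicing above implies, the K heads and (presumably) the V heads are packed along the head axis of a single tensor, and `unsqueeze(1)` exposes a page axis of size 1 so the buffer can be consumed as a paged cache at token granularity. A small sketch of that reshaping under those assumed shapes (the sizes here are made up for illustration):

```python
import torch

# Assumed layout, inferred from the slicing in the kernel above:
# kv_buffer[layer] packs K and V along the head axis.
total_tokens, k_heads, v_heads, head_dim = 10, 8, 8, 128
kv_layer = torch.randn(total_tokens, k_heads + v_heads, head_dim)

# unsqueeze(1) adds a page axis of size 1 -> [num_pages, page_size=1, heads, dim],
# i.e. every cached token becomes its own "page" for the paged prefill wrapper.
kv_paged = kv_layer.unsqueeze(1)
k_cache = kv_paged[:, :, :k_heads, :]   # [total_tokens, 1, k_heads, head_dim]
v_cache = kv_paged[:, :, k_heads:, :]   # [total_tokens, 1, v_heads, head_dim]

assert k_cache.shape == (total_tokens, 1, k_heads, head_dim)
assert v_cache.shape == (total_tokens, 1, v_heads, head_dim)
```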
     def _context_attention_kernel(
         self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -254,7 +286,6 @@ def _context_attention_kernel_ppl_int8kv(
         return o_tensor
 
     def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None):
-
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
@@ -264,7 +295,7 @@ def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateI
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
+        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
             q=q,
             k_cache=cache_k,
@@ -392,6 +423,19 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
         )
         return
 
+    def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1)
+        infer_state.decode_wrapper.run(
+            q.view(calcu_shape1),
+            (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+            out=o_tensor.view(calcu_shape1),
+        )
+        return o_tensor
+
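The `decode_wrapper` (and the prefill wrappers above) are built and planned in `LlamaFlashInferStateInfo`, which this diff imports but does not show. For orientation, here is a hedged sketch of how flashinfer's public `BatchDecodeWithPagedKVCacheWrapper` is typically set up for the page_size=1 layout used above; the workspace size, paging metadata, and dtypes are illustrative assumptions, not the values lightllm actually passes:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1
batch_size, total_tokens = 2, 10

# Illustrative setup; lightllm's real wrapper construction lives in LlamaFlashInferStateInfo.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# With page_size=1 every cached token is one page: request 0 owns tokens 0..5,
# request 1 owns tokens 6..9 (made-up paging metadata).
kv_indptr = torch.tensor([0, 6, 10], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(total_tokens, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.float16,
)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(total_tokens, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Same call shape as the kernel above: q is [B, H, D], (k_cache, v_cache) is the paged NHD cache.
out = wrapper.run(q, (k_cache, v_cache))
```

The plan()/run() split lets the index preprocessing be amortized across all layers of one forward pass, which is presumably why the wrapper lives on the infer-state object rather than on the layer.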
     def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         total_token_num = infer_state.total_token_num
         batch_size = infer_state.batch_size
@@ -565,7 +609,7 @@ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo,
         # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
         fp16_decode_attention(
             o_tensor.view(calcu_shape1),
-            1.0 / (self.head_dim_ ** 0.5),
+            1.0 / (self.head_dim_ ** 0.5),
             q.view(calcu_shape1),
             infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
             infer_state.mem_manager.kv_buffer[self.layer_num_][
@@ -673,7 +717,6 @@ def _token_decode_attention_gqa_flashdecoding_vsm(
         )
 
     def _token_decode_attention_flashattention(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
-
         cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape(
             -1, 1, self.tp_k_head_num_, self.head_dim_
         )
@@ -683,7 +726,7 @@ def _token_decode_attention_flashattention(self, q, infer_state: LlamaInferState
         q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_)
         k_descale, v_descale = None, None  # disable quantization
         Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
+        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
             q=q,
             k_cache=cache_k,
@@ -711,7 +754,6 @@ def overlap_tpsp_token_forward(
         infer_state1: LlamaInferStateInfo,
         layer_weight: LlamaTransformerLayerWeight,
     ):
-
         input_embdings = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight)
         input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight)
         return input_embdings, input_embdings1