20 | 20 | from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd |
21 | 21 |
22 | 22 | from lightllm.models.llama.infer_struct import LlamaInferStateInfo |
| 23 | +from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo |
23 | 24 | from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv |
24 | 25 | from lightllm.common.basemodel import TransformerLayerInferTpl |
25 | 26 | from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv |
@@ -68,8 +69,12 @@ def _bind_attention(self): |
68 | 69 | ) |
69 | 70 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
70 | 71 | return |
71 | | - |
72 | | - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) |
| 72 | + elif get_env_start_args().enable_flashinfer_prefill: |
| 73 | + self._context_attention_kernel = partial( |
| 74 | + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self |
| 75 | + ) |
| 76 | + else: |
| 77 | + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) |
73 | 78 | if "ppl_int8kv" in self.mode: |
74 | 79 | self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self) |
75 | 80 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) |
@@ -119,7 +124,12 @@ def _bind_attention(self): |
119 | 124 | ) |
120 | 125 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
121 | 126 | else: |
122 | | - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) |
| 127 | + if get_env_start_args().enable_flashinfer_decode: |
| 128 | + self._token_attention_kernel = partial( |
| 129 | + LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self |
| 130 | + ) |
| 131 | + else: |
| 132 | + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) |
123 | 133 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
124 | 134 |
125 | 135 | return |
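
Aside: the two new branches above follow the existing `_bind_attention` pattern of resolving the kernel once at startup and freezing `self` into it with `functools.partial`, so the per-token hot path never re-checks `get_env_start_args()`. A minimal, standalone sketch of that idiom (class and method names here are illustrative, not from lightllm):

```python
from functools import partial


class KernelHost:
    def _kernel_default(self, x):
        return x + 1

    def _kernel_flashinfer(self, x):  # stand-in for the FlashInfer-backed path
        return x * 2

    def bind(self, use_flashinfer: bool):
        # Resolve the implementation once, then bind `self` into the unbound
        # function object, mirroring partial(LlamaTransformerLayerInfer.<impl>, self).
        impl = KernelHost._kernel_flashinfer if use_flashinfer else KernelHost._kernel_default
        self.kernel = partial(impl, self)


host = KernelHost()
host.bind(use_flashinfer=True)
assert host.kernel(3) == 6
```
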
@@ -178,6 +188,28 @@ def _tpsp_get_qkv( |
178 | 188 | ) |
179 | 189 | return q, cache_kv |
180 | 190 |
| 191 | + def _context_attention_flashinfer_kernel( |
| 192 | + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None |
| 193 | + ) -> torch.Tensor: |
| 194 | + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out |
| 195 | + if infer_state.use_dynamic_prompt_cache: |
| 196 | + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] |
| 197 | + kv = kv.unsqueeze(1) |
| 198 | + infer_state.prefill_wrapper.run( |
| 199 | + q.view(q.shape[0], -1, self.head_dim_), |
| 200 | + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), |
| 201 | + out=o_tensor.view(q.shape[0], -1, self.head_dim_), |
| 202 | + ) |
| 203 | + else: |
| 204 | + infer_state.prefill_wrapper.run( |
| 205 | + q.view(q.shape[0], -1, self.head_dim_), |
| 206 | + kv[:, : self.tp_k_head_num_, :], |
| 207 | + kv[:, self.tp_k_head_num_ :, :], |
| 208 | + out=o_tensor.view(q.shape[0], -1, self.head_dim_), |
| 209 | + ) |
| 210 | + |
| 211 | + return o_tensor |
| 212 | + |
181 | 213 | def _context_attention_kernel( |
182 | 214 | self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None |
183 | 215 | ) -> torch.Tensor: |
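
The paged branch of `_context_attention_flashinfer_kernel` leans on a layout trick: lightllm's `mem_manager.kv_buffer[layer]` stores K and V fused along the head axis (K heads first, then V heads, as the `tp_k_head_num_` slicing shows), and `unsqueeze(1)` reinterprets that token-granular buffer as a paged KV cache with page size 1, so no copy is needed before handing it to the wrapper. A torch-only sketch of the shapes involved (sizes are illustrative):

```python
import torch

# Illustrative sizes: 64 cached token slots, 8 K heads + 8 V heads, head_dim 128.
total_tokens, kv_heads, head_dim = 64, 8, 128

# kv_buffer[layer]: [total_tokens, k_heads + v_heads, head_dim], K heads first.
kv_buffer = torch.randn(total_tokens, 2 * kv_heads, head_dim)

# unsqueeze(1) adds a page_size dimension of 1, i.e. every cached token
# becomes its own page: [num_pages, page_size, heads, head_dim].
paged = kv_buffer.unsqueeze(1)

k_cache = paged[:, :, :kv_heads, :]  # kv[:, :, : self.tp_k_head_num_, :] in the diff
v_cache = paged[:, :, kv_heads:, :]  # kv[:, :, self.tp_k_head_num_ :, :] in the diff
assert k_cache.shape == (total_tokens, 1, kv_heads, head_dim)
assert v_cache.shape == (total_tokens, 1, kv_heads, head_dim)
```
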
@@ -254,7 +286,6 @@ def _context_attention_kernel_ppl_int8kv( |
254 | 286 | return o_tensor |
255 | 287 |
256 | 288 | def _context_attention_flashattention(self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None): |
257 | | - |
258 | 289 | cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( |
259 | 290 | -1, 1, self.tp_k_head_num_, self.head_dim_ |
260 | 291 | ) |
@@ -392,6 +423,19 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): |
392 | 423 | ) |
393 | 424 | return |
394 | 425 |
| 426 | + def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): |
| 427 | + batch_size = infer_state.batch_size |
| 428 | + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) |
| 429 | + |
| 430 | + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out |
| 431 | + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) |
| 432 | + infer_state.decode_wrapper.run( |
| 433 | + q.view(calcu_shape1), |
| 434 | + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), |
| 435 | + out=o_tensor.view(calcu_shape1), |
| 436 | + ) |
| 437 | + return o_tensor |
| 438 | + |
395 | 439 | def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): |
396 | 440 | total_token_num = infer_state.total_token_num |
397 | 441 | batch_size = infer_state.batch_size |
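
The new `_token_decode_attention_flashinfer` only calls `infer_state.decode_wrapper.run(...)`; the wrapper itself is built and planned in `LlamaFlashInferStateInfo` (not part of this diff). For orientation, a rough sketch of how such a wrapper is typically set up with flashinfer's `BatchDecodeWithPagedKVCacheWrapper` under the page-size-1 layout used here; the planning arguments and tensors below are assumptions for illustration, not code from this PR, and keyword details vary across flashinfer versions:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1

# One reusable workspace buffer; NHD layout matches [pages, page_size, heads, head_dim].
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# With page_size == 1 every cached token is its own page, so the page indices
# are simply each request's token slots and every "last page" holds one token.
# Example: two requests holding 5 and 4 cached tokens.
kv_indptr = torch.tensor([0, 5, 9], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(9, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(2, dtype=torch.int32, device="cuda")

decode_wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
)

# q: [batch, num_qo_heads, head_dim]; k/v caches: [num_pages, 1, num_kv_heads, head_dim]
# out = decode_wrapper.run(q, (k_cache, v_cache))
```
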
@@ -673,7 +717,6 @@ def _token_decode_attention_gqa_flashdecoding_vsm( |
673 | 717 | ) |
674 | 718 |
675 | 719 | def _token_decode_attention_flashattention(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): |
676 | | - |
677 | 720 | cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( |
678 | 721 | -1, 1, self.tp_k_head_num_, self.head_dim_ |
679 | 722 | ) |
@@ -711,7 +754,6 @@ def overlap_tpsp_token_forward( |
711 | 754 | infer_state1: LlamaInferStateInfo, |
712 | 755 | layer_weight: LlamaTransformerLayerWeight, |
713 | 756 | ): |
714 | | - |
715 | 757 | input_embdings = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) |
716 | 758 | input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) |
717 | 759 | return input_embdings, input_embdings1 |