20 | 20 | from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd |
21 | 21 |
22 | 22 | from lightllm.models.llama.infer_struct import LlamaInferStateInfo |
| 23 | +from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo |
23 | 24 | from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv |
24 | 25 | from lightllm.common.basemodel import TransformerLayerInferTpl |
25 | 26 | from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv |
@@ -68,8 +69,12 @@ def _bind_attention(self): |
68 | 69 | ) |
69 | 70 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
70 | 71 | return |
71 | | - |
72 | | - self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) |
| 72 | + elif get_env_start_args().enable_flashinfer_prefill: |
| 73 | + self._context_attention_kernel = partial( |
| 74 | + LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self |
| 75 | + ) |
| 76 | + else: |
| 77 | + self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self) |
73 | 78 | if "ppl_int8kv" in self.mode: |
74 | 79 | self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self) |
75 | 80 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self) |
@@ -119,7 +124,12 @@ def _bind_attention(self): |
119 | 124 | ) |
120 | 125 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
121 | 126 | else: |
122 | | - self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) |
| 127 | + if get_env_start_args().enable_flashinfer_decode: |
| 128 | + self._token_attention_kernel = partial( |
| 129 | + LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self |
| 130 | + ) |
| 131 | + else: |
| 132 | + self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self) |
123 | 133 | self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self) |
124 | 134 |
125 | 135 | return |
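
The two hunks above extend `_bind_attention` so that, when the server is started with `enable_flashinfer_prefill` / `enable_flashinfer_decode`, the FlashInfer kernels are bound instead of the default Triton ones. Below is a minimal sketch of the `functools.partial` dispatch pattern used here; the class and flag names are illustrative, not LightLLM's.

```python
from functools import partial


class AttnInfer:
    """Illustrative only: mirrors how _bind_attention picks a kernel once at init."""

    def __init__(self, enable_flashinfer: bool):
        # Bind the chosen unbound method to this instance so the hot path calls
        # self._attention_kernel(...) without re-checking the flag per request.
        if enable_flashinfer:
            self._attention_kernel = partial(AttnInfer._attention_flashinfer, self)
        else:
            self._attention_kernel = partial(AttnInfer._attention_triton, self)

    def _attention_flashinfer(self, q):
        return f"flashinfer({q})"

    def _attention_triton(self, q):
        return f"triton({q})"


print(AttnInfer(True)._attention_kernel("q"))   # flashinfer(q)
print(AttnInfer(False)._attention_kernel("q"))  # triton(q)
```

Binding once at construction keeps the per-token code path free of mode checks, which is why the new branches live in `_bind_attention` rather than in the attention methods themselves.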
@@ -178,6 +188,28 @@ def _tpsp_get_qkv( |
178 | 188 | ) |
179 | 189 | return q, cache_kv |
180 | 190 |
| 191 | + def _context_attention_flashinfer_kernel( |
| 192 | + self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None |
| 193 | + ) -> torch.Tensor: |
| 194 | + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out |
| 195 | + if infer_state.use_dynamic_prompt_cache: |
| 196 | + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] |
| 197 | + kv = kv.unsqueeze(1) |
| 198 | + infer_state.prefill_wrapper.run( |
| 199 | + q.view(q.shape[0], -1, self.head_dim_), |
| 200 | + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), |
| 201 | + out=o_tensor.view(q.shape[0], -1, self.head_dim_), |
| 202 | + ) |
| 203 | + else: |
| 204 | + infer_state.prefill_wrapper.run( |
| 205 | + q.view(q.shape[0], -1, self.head_dim_), |
| 206 | + kv[:, : self.tp_k_head_num_, :], |
| 207 | + kv[:, self.tp_k_head_num_ :, :], |
| 208 | + out=o_tensor.view(q.shape[0], -1, self.head_dim_), |
| 209 | + ) |
| 210 | + |
| 211 | + return o_tensor |
| 212 | + |
181 | 213 | def _context_attention_kernel( |
182 | 214 | self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None |
183 | 215 | ) -> torch.Tensor: |
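
The new `_context_attention_flashinfer_kernel` reads the fused per-layer KV buffer when the dynamic prompt cache is in use, adds a page dimension with `unsqueeze(1)`, and slices the first `tp_k_head_num_` heads as K and the remaining heads as V before handing them to the prefill wrapper. A small standalone sketch of that layout follows; the concrete shapes are assumptions chosen for illustration.

```python
import torch

# Illustrative shapes only: K and V heads are stored fused along one axis of the
# per-layer buffer, and unsqueeze(1) turns it into a paged layout with page_size == 1.
tokens, tp_k_head_num, head_dim = 8, 4, 64
kv_buffer = torch.randn(tokens, 2 * tp_k_head_num, head_dim)  # [K heads | V heads]

kv = kv_buffer.unsqueeze(1)          # (tokens, 1, 2 * tp_k_head_num, head_dim)
k = kv[:, :, :tp_k_head_num, :]      # views into the buffer, no copy
v = kv[:, :, tp_k_head_num:, :]
assert k.shape == v.shape == (tokens, 1, tp_k_head_num, head_dim)
```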
@@ -392,6 +424,19 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager): |
392 | 424 | ) |
393 | 425 | return |
394 | 426 |
| 427 | + def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None): |
| 428 | + batch_size = infer_state.batch_size |
| 429 | + calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) |
| 430 | + |
| 431 | + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out |
| 432 | + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) |
| 433 | + infer_state.decode_wrapper.run( |
| 434 | + q.view(calcu_shape1), |
| 435 | + (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), |
| 436 | + out=o_tensor.view(calcu_shape1), |
| 437 | + ) |
| 438 | + return o_tensor |
| 439 | + |
395 | 440 | def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): |
396 | 441 | total_token_num = infer_state.total_token_num |
397 | 442 | batch_size = infer_state.batch_size |
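
`_token_decode_attention_flashinfer` assumes `infer_state.decode_wrapper` has already been planned for the current batch; that setup lives in `LlamaFlashInferStateInfo`, which is not part of this diff. The following is a hedged sketch of how such a FlashInfer decode wrapper is typically created and invoked, not LightLLM's actual setup code; the exact `plan()` arguments depend on the installed flashinfer version and on how the state object lays out its page metadata.

```python
import torch
import flashinfer  # assumed installed; API names follow recent flashinfer releases

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Per-batch page metadata (illustrative values): two requests holding 5 and 4 pages.
kv_indptr = torch.tensor([0, 5, 9], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(9, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(2, dtype=torch.int32, device="cuda")

wrapper.plan(kv_indptr, kv_indices, kv_last_page_len,
             num_qo_heads, num_kv_heads, head_dim, page_size)

# q: (batch, num_qo_heads, head_dim); K/V caches: (pages, page_size, num_kv_heads, head_dim),
# matching the unsqueeze(1) view of the fused buffer used in the kernel above.
q = torch.randn(2, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(9, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
out = wrapper.run(q, (k_cache, v_cache))  # the diff passes out= to write in place instead
```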
@@ -565,7 +610,7 @@ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, |
565 | 610 | # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch) |
566 | 611 | fp16_decode_attention( |
567 | 612 | o_tensor.view(calcu_shape1), |
568 | | - 1.0 / (self.head_dim_ ** 0.5), |
| 613 | + 1.0 / (self.head_dim_**0.5), |
569 | 614 | q.view(calcu_shape1), |
570 | 615 | infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :], |
571 | 616 | infer_state.mem_manager.kv_buffer[self.layer_num_][ |
@@ -711,7 +756,6 @@ def overlap_tpsp_token_forward( |
711 | 756 | infer_state1: LlamaInferStateInfo, |
712 | 757 | layer_weight: LlamaTransformerLayerWeight, |
713 | 758 | ): |
714 | | - |
715 | 759 | input_embdings = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight) |
716 | 760 | input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight) |
717 | 761 | return input_embdings, input_embdings1 |