Commit e25cccb

feat: add flashinfer backend for llama

Author: niushengxiao
Parent: 082a408

File tree: 6 files changed, +554 −6 lines


lightllm/models/deepseek2/model.py
Lines changed: 3 additions & 1 deletion

@@ -20,7 +20,7 @@
 logger = init_logger(__name__)
 
 
-class FlashInferStateExtraInfo:
+class DeepSeek2FlashInferStateExtraInfo:
     def __init__(self, model):
         num_heads = model.config["num_attention_heads"]
         self.tp_q_head_num = num_heads // get_dp_world_size()
@@ -84,6 +84,8 @@ def _init_some_value(self):
         self.q_lora_rank = self.config["q_lora_rank"]
         self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
+        if self.enable_flashinfer:
+            self.flashinfer_extra_state = DeepSeek2FlashInferStateExtraInfo(self)
 
     def _init_custom(self):
         self._init_to_get_yarn_rotary()
lightllm/models/llama/flashinfer_struct.py (new file)
Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+
+class LlamaFlashInferStateInfo(LlamaInferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.prefill_wrapper = None
+        self.decode_wrapper = None
+        self.flashinfer_extra_state = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        self.flashinfer_extra_state = model.flashinfer_extra_state
+
+        import flashinfer
+
+        if not self.is_prefill:
+            if get_env_start_args().enable_flashinfer_decode:
+                self.kv_last_page_len_buffer = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
+                self.kv_indices = torch.empty(
+                    self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
+                ).to(input_ids.device)
+                repack_kv_index(
+                    self.req_manager.req_to_token_indexs,
+                    self.b_req_idx,
+                    self.b_seq_len,
+                    self.b_start_loc,
+                    self.max_len_in_batch,
+                    self.kv_indices,
+                )
+                self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+                if self.decode_wrapper is None:
+                    self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_extra_state.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=True,
+                        paged_kv_indptr_buffer=self.kv_starts,
+                        paged_kv_indices_buffer=self.kv_indices,
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer,
+                    )
+                self.decode_wrapper.plan(
+                    self.kv_starts,
+                    self.kv_indices,
+                    self.kv_last_page_len_buffer,
+                    self.flashinfer_extra_state.tp_q_head_num,
+                    self.flashinfer_extra_state.tp_kv_head_num,
+                    self.flashinfer_extra_state.head_dim,
+                    1,
+                    q_data_type=self.flashinfer_extra_state.q_data_type,
+                    kv_data_type=self.flashinfer_extra_state.kv_data_type,
+                    non_blocking=True,
+                )
+        else:
+            if get_env_start_args().enable_flashinfer_prefill:
+                q_starts = torch.zeros((self.batch_size + 1,)).int().cuda()
+                q_starts[1:] = torch.cumsum(self.b_seq_len - self.b_ready_cache_len, dim=0)
+                kv_starts = torch.zeros_like(q_starts)
+                kv_starts[1:] = torch.cumsum(self.b_seq_len, dim=0)
+                kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32).to(input_ids.device)
+                if self.use_dynamic_prompt_cache:
+                    kv_indices = torch.empty(
+                        self.batch_size * self.flashinfer_extra_state.max_seq_length, dtype=torch.int32
+                    ).to(input_ids.device)
+                    repack_kv_index(
+                        self.req_manager.req_to_token_indexs,
+                        self.b_req_idx,
+                        self.b_seq_len,
+                        self.b_start_loc,
+                        self.max_len_in_batch,
+                        kv_indices,
+                    )
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_extra_state.workspace_buffer,
+                        qo_indptr_buf=q_starts,
+                        paged_kv_indptr_buf=kv_starts,
+                        paged_kv_indices_buf=kv_indices,
+                        paged_kv_last_page_len_buf=kv_last_page_len,
+                    )
+                    self.prefill_wrapper.plan(
+                        q_starts,
+                        kv_starts,
+                        kv_indices,
+                        kv_last_page_len,
+                        self.flashinfer_extra_state.tp_q_head_num,
+                        self.flashinfer_extra_state.tp_kv_head_num,
+                        self.flashinfer_extra_state.head_dim,
+                        1,
+                        causal=True,
+                        pos_encoding_mode="NONE",
+                        logits_soft_cap=0.0,
+                        q_data_type=self.flashinfer_extra_state.q_data_type,
+                        kv_data_type=self.flashinfer_extra_state.kv_data_type,
+                    )
+                else:
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        self.flashinfer_extra_state.workspace_buffer,
+                    )
+                    self.prefill_wrapper.plan(
+                        qo_indptr=q_starts,
+                        kv_indptr=kv_starts,
+                        num_qo_heads=self.flashinfer_extra_state.tp_q_head_num,
+                        num_kv_heads=self.flashinfer_extra_state.tp_kv_head_num,
+                        head_dim_qk=self.flashinfer_extra_state.head_dim,
+                        head_dim_vo=self.flashinfer_extra_state.head_dim,
+                        causal=True,
+                        q_data_type=self.flashinfer_extra_state.q_data_type,
+                        kv_data_type=self.flashinfer_extra_state.kv_data_type,
+                    )
+        return
+
+    def copy_for_cuda_graph(self, new_infer_state):
+        super().copy_for_cuda_graph(new_infer_state)
+        if get_env_start_args().enable_flashinfer_decode and not self.is_prefill:
+            self.decode_wrapper.plan(
+                new_infer_state.kv_starts,
+                new_infer_state.kv_indices,
+                new_infer_state.kv_last_page_len_buffer,
+                new_infer_state.flashinfer_extra_state.tp_q_head_num,
+                new_infer_state.flashinfer_extra_state.tp_kv_head_num,
+                new_infer_state.flashinfer_extra_state.head_dim,
+                1,
+                q_data_type=new_infer_state.flashinfer_extra_state.q_data_type,
+                kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type,
+                non_blocking=True,
+            )
+        return
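The commit leaves the flashinfer wrapper lifecycle implicit, so here is a minimal standalone sketch of the same plan-once / run-per-layer pattern that LlamaFlashInferStateInfo sets up for decode. Only the wrapper class, the plan()/run() calls, and the 256 MB workspace mirror the commit; the batch size, page table, random tensors, and the assumption that flashinfer is installed on a CUDA device are illustrative only.

import torch
import flashinfer

# Toy sizes (assumptions for illustration only).
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
batch_size, page_size = 2, 1  # page_size=1 matches the commit's one-token-per-page layout

# Same 256 MB workspace the commit allocates in LlamaFlashInferStateExtraInfo.
workspace = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# Page-table metadata: request 0 owns pages 0..6, request 1 owns pages 7..15.
kv_indptr = torch.tensor([0, 7, 16], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(16, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32, device="cuda")

# plan() once per batch with the page-table metadata ...
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.float16,
    kv_data_type=torch.float16,
)

# ... then run() once per transformer layer against that layer's paged KV cache.
q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(16, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)
out = wrapper.run(q, (k_cache, v_cache))  # [batch_size, num_qo_heads, head_dim]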

lightllm/models/llama/layer_infer/transformer_layer_infer.py
Lines changed: 49 additions & 5 deletions

@@ -20,6 +20,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv, destindex_copy_quantize_kv
 from lightllm.common.basemodel import TransformerLayerInferTpl
 from lightllm.models.llama.triton_kernel.ppl_quant_copy_kv import destindex_copy_dequantize_kv
@@ -68,8 +69,12 @@ def _bind_attention(self):
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
             return
-
-        self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self)
+        elif get_env_start_args().enable_flashinfer_prefill:
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_flashinfer_kernel, self
+            )
+        else:
+            self._context_attention_kernel = partial(LlamaTransformerLayerInfer._context_attention_kernel, self)
         if "ppl_int8kv" in self.mode:
             self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
@@ -119,7 +124,12 @@ def _bind_attention(self):
             )
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         else:
-            self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
+            if get_env_start_args().enable_flashinfer_decode:
+                self._token_attention_kernel = partial(
+                    LlamaTransformerLayerInfer._token_decode_attention_flashinfer, self
+                )
+            else:
+                self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
             self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
 
         return
@@ -178,6 +188,28 @@ def _tpsp_get_qkv(
         )
         return q, cache_kv
 
+    def _context_attention_flashinfer_kernel(
+        self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            kv = kv.unsqueeze(1)
+            infer_state.prefill_wrapper.run(
+                q.view(q.shape[0], -1, self.head_dim_),
+                (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+                out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+            )
+        else:
+            infer_state.prefill_wrapper.run(
+                q.view(q.shape[0], -1, self.head_dim_),
+                kv[:, : self.tp_k_head_num_, :],
+                kv[:, self.tp_k_head_num_ :, :],
+                out=o_tensor.view(q.shape[0], -1, self.head_dim_),
+            )
+
+        return o_tensor
+
     def _context_attention_kernel(
         self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -392,6 +424,19 @@ def _copy_kv_to_mem_cache_ppl_int4kv(self, buffer, mem_index, mem_manager):
         )
         return
 
+    def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+
+        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1)
+        infer_state.decode_wrapper.run(
+            q.view(calcu_shape1),
+            (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]),
+            out=o_tensor.view(calcu_shape1),
+        )
+        return o_tensor
+
     def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
         total_token_num = infer_state.total_token_num
         batch_size = infer_state.batch_size
@@ -565,7 +610,7 @@ def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo,
         # at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
         fp16_decode_attention(
             o_tensor.view(calcu_shape1),
-            1.0 / (self.head_dim_ ** 0.5),
+            1.0 / (self.head_dim_**0.5),
             q.view(calcu_shape1),
             infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
             infer_state.mem_manager.kv_buffer[self.layer_num_][
@@ -711,7 +756,6 @@ def overlap_tpsp_token_forward(
         infer_state1: LlamaInferStateInfo,
         layer_weight: LlamaTransformerLayerWeight,
     ):
-
        input_embdings = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight)
        input_embdings1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight)
        return input_embdings, input_embdings1
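A small sketch of the KV layout the two new kernels rely on: lightllm stores K and V fused along the head axis of one per-layer buffer, and _context_attention_flashinfer_kernel / _token_decode_attention_flashinfer hand flashinfer the two halves as views, adding a page_size-1 axis so the buffer looks like a paged cache with one token per page. The shapes below are assumptions chosen only for illustration.

import torch

# Toy per-layer KV buffer: [num_tokens, tp_k_head_num + tp_v_head_num, head_dim]
num_tokens, tp_k_head_num, tp_v_head_num, head_dim = 6, 4, 4, 128
kv_buffer = torch.randn(num_tokens, tp_k_head_num + tp_v_head_num, head_dim)

# Insert a page_size=1 axis, mirroring kv.unsqueeze(1) in the commit.
kv = kv_buffer.unsqueeze(1)              # [num_tokens, 1, k_heads + v_heads, head_dim]
k_cache = kv[:, :, :tp_k_head_num, :]    # first tp_k_head_num heads are K
v_cache = kv[:, :, tp_k_head_num:, :]    # remaining heads are V

# Both halves are views into kv_buffer, so no copy is made.
assert k_cache.shape == (num_tokens, 1, tp_k_head_num, head_dim)
assert k_cache.data_ptr() == kv_buffer.data_ptr()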

lightllm/models/llama/model.py
Lines changed: 21 additions & 0 deletions

@@ -12,14 +12,28 @@
 
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo
+from lightllm.models.llama.flashinfer_struct import LlamaFlashInferStateInfo
 from lightllm.common.basemodel import TpPartBaseModel
 from lightllm.common.mem_utils import select_mem_manager_class
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_dp_world_size, get_current_device_id
 
 logger = init_logger(__name__)
 
 
+class LlamaFlashInferStateExtraInfo:
+    def __init__(self, model):
+        tp_world_size = get_dp_world_size()
+        self.tp_q_head_num = model.config["num_attention_heads"] // tp_world_size
+        self.tp_kv_head_num = model.config["num_key_value_heads"] // tp_world_size
+        self.head_dim = model.config["hidden_size"] // model.config["num_attention_heads"]
+        self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(get_current_device_id())
+        self.max_seq_length = model.max_seq_length
+        self.q_data_type = model.data_type
+        self.kv_data_type = model.data_type
+
+
 class LlamaTpPartModel(TpPartBaseModel):
     # weight class
     pre_and_post_weight_class = LlamaPreAndPostLayerWeight
@@ -34,6 +48,11 @@ class LlamaTpPartModel(TpPartBaseModel):
     infer_state_class = LlamaInferStateInfo
 
     def __init__(self, kvargs):
+        self.enable_flashinfer = (
+            get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode
+        )
+        if self.enable_flashinfer:
+            self.infer_state_class = LlamaFlashInferStateInfo
         super().__init__(kvargs)
         return
 
@@ -42,6 +61,8 @@ def _init_config(self):
         # rename key
         # repair_config()
         self._reset_num_key_value_heads()
+        if self.enable_flashinfer:
+            self.flashinfer_extra_state = LlamaFlashInferStateExtraInfo(self)
         return
 
     def _reset_num_key_value_heads(self):
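To make the per-rank partitioning in LlamaFlashInferStateExtraInfo concrete, here is a hedged sketch that reproduces its arithmetic with made-up Llama-70B-like config numbers; get_dp_world_size() is replaced by a plain integer so it runs standalone.

# Hypothetical config values for illustration; the real ones come from model.config.
config = {"num_attention_heads": 64, "num_key_value_heads": 8, "hidden_size": 8192}
tp_world_size = 4  # stand-in for get_dp_world_size()

# Same arithmetic as LlamaFlashInferStateExtraInfo.__init__
tp_q_head_num = config["num_attention_heads"] // tp_world_size      # 64 // 4 = 16 query heads per rank
tp_kv_head_num = config["num_key_value_heads"] // tp_world_size     # 8 // 4 = 2 KV heads per rank
head_dim = config["hidden_size"] // config["num_attention_heads"]   # 8192 // 64 = 128

print(tp_q_head_num, tp_kv_head_num, head_dim)  # -> 16 2 128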
