
Commit 24d6fc4

Author: niushengxiao
feat: add flashinfer prefilled operator in the attention module
1 parent: ca1a105

File tree: 3 files changed (+110, -54 lines)


lightllm/models/deepseek2/infer_struct.py

Lines changed: 65 additions & 23 deletions
@@ -3,15 +3,19 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
-import flashinfer


 class Deepseek2InferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.kv_starts = None
+        self.wrapper = None
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -20,12 +24,24 @@ def __init__(self):

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
-        # this management variable only exists when the decode stage uses the ppl optimized operator
+
         if not self.is_prefill:
            self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
            self.total_token_num_tensor = torch.sum(self.b_seq_len)
            if self.enable_flashinfer_decode_mla:
-                self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                import flashinfer
+                from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                )
+                self.kv_lora_rank = model.kv_lora_rank
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+                self.kv_data_type = model.data_type
+
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(self.batch_size * model.max_seq_length, dtype=torch.int32).to(
                     input_ids.device
@@ -38,27 +54,23 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.max_len_in_batch,
                     self.kv_indices,
                 )
-                self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                    self.workspace_buffer,
-                    backend="fa2",
-                    use_cuda_graph=True,
-                    qo_indptr=self.q_indptr,
-                    kv_indices=self.kv_indices,
-                    kv_indptr=self.kv_starts,
-                    kv_len_arr=self.b_seq_len,
-                )
-                self.head_num = model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
-                self.kv_lora_rank = model.kv_lora_rank
-                self.qk_rope_head_dim = model.qk_rope_head_dim
-                self.softmax_scale = model.softmax_scale
-                self.q_data_type = model.data_type
-                self.kv_data_type = model.data_type
+                if not self.wrapper:
+                    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                    self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                        workspace_buffer,
+                        backend="fa2",
+                        use_cuda_graph=True,
+                        qo_indptr=self.q_indptr,
+                        kv_indices=self.kv_indices,
+                        kv_indptr=self.kv_starts,
+                        kv_len_arr=self.b_seq_len,
+                    )
                 self.wrapper.plan(
                     self.q_indptr,
                     self.kv_starts,
                     self.kv_indices,
                     self.b_seq_len,
-                    self.head_num,
+                    self.tp_q_head_num,
                     self.kv_lora_rank,
                     self.qk_rope_head_dim,
                     1,
@@ -67,9 +79,39 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.q_data_type,
                     self.kv_data_type,
                 )
-
-        if self.is_prefill:
+        else:
             self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+            if self.enable_flashinfer_prefilled:
+                import flashinfer
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                )
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+
+                q_starts = torch.cat(
+                    [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
+                ).int()
+                kv_starts = torch.cat(
+                    [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
+                ).int()
+                if not self.wrapper:
+                    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+                    self.wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
+                self.wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.tp_q_head_num,
+                    num_kv_heads=self.tp_q_head_num,
+                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                    head_dim_vo=self.qk_nope_head_dim,
+                    q_data_type=self.q_data_type,
+                    causal=True,
+                    sm_scale=self.softmax_scale,
+                )

         if self.enable_dp:
             rank = dist.get_rank()
@@ -95,7 +137,7 @@ def copy_for_cuda_graph(self, new_infer_state):
                 self.kv_starts,
                 self.kv_indices,
                 self.b_seq_len,
-                self.head_num,
+                self.tp_q_head_num,
                 self.kv_lora_rank,
                 self.qk_rope_head_dim,
                 1,
lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 44 additions & 29 deletions
@@ -68,6 +68,11 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -223,22 +228,28 @@ def _context_attention_kernel_with_CC(
         out=None,
     ) -> torch.Tensor:
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
-        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
-        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
         )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.wrapper.run(q, k, v, out=o_tensor)
+        else:
+            q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor

     def _context_attention_kernel_with_CC_fp8(
@@ -252,20 +263,24 @@ def _context_attention_kernel_with_CC_fp8(
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
-        )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.wrapper.run(q, k, v, out=o_tensor)
+        else:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor

     def _context_attention_kernel_origin(
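
In the new flashinfer branch of the CC attention kernels, q keeps its full (nope + rope) head dimension, the shared rope part of K is broadcast across the query heads before being concatenated with k_nope, and the output buffer only carries the nope (value) head dimension. A rough shape-only sketch of that bookkeeping, with toy sizes and illustrative names, assuming (as the repeat_interleave call implies) that k_rope holds a single shared rope head per token:

# Shape-only sketch of the tensor bookkeeping in the flashinfer prefill branch above.
# Toy sizes; k_rope is assumed to carry one shared rope head per token.
import torch

tp_q_head_num = 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64
total_tokens = 8

q = torch.randn(total_tokens, tp_q_head_num, qk_nope_head_dim + qk_rope_head_dim)  # full Q head dim
k_nope = torch.randn(total_tokens, tp_q_head_num, qk_nope_head_dim)
k_rope = torch.randn(total_tokens, 1, qk_rope_head_dim)                            # shared rope head
v = torch.randn(total_tokens, tp_q_head_num, qk_nope_head_dim)

# Same construction as the diff: broadcast the rope head across query heads, then
# concatenate along the head dimension so K matches Q's head_dim_qk.
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, tp_q_head_num, dim=-2)], dim=-1)
assert k.shape == (total_tokens, tp_q_head_num, qk_nope_head_dim + qk_rope_head_dim)

# The attention output is allocated with the value head dim only, matching
# alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), ...).
o_tensor = torch.empty(total_tokens, tp_q_head_num, qk_nope_head_dim)
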

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 1 addition & 2 deletions
@@ -179,8 +179,7 @@ def _fwd_kernel_calcu_index_and_block_seq(
     req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX)
     b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
     b_start_loc = torch.arange(Z).cuda().int() * N_CTX
-    b_start_loc[0] = 0
-    b_req_idx = torch.arange(Z).cuda().int()
+    b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda()
     kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)

     o = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")
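
The benchmark tweak above drops the redundant b_start_loc[0] = 0 assignment (torch.arange already starts at 0) and randomizes b_req_idx so the request-to-token mapping is no longer trivially ordered. For reference, a toy CPU run of the kv_starts construction the test keeps, with made-up sizes:

# Toy illustration of the kv_starts construction kept in the benchmark above (CPU, made-up sizes).
import torch

Z, N_CTX = 3, 4
b_seq_len = torch.ones((Z,), dtype=torch.int32) * N_CTX   # [4, 4, 4]
b_start_loc = torch.arange(Z).int() * N_CTX               # [0, 4, 8], already starts at 0
kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)
print(kv_starts.tolist())                                 # [0, 4, 8, 12]
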
