
Commit fe02a31

Author: niushengxiao (committed)
feat: add flashinfer prefilled operator in the attention module
1 parent 3c51248 commit fe02a31

File tree: 3 files changed (+125, -66 lines)
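Both flashinfer code paths are opt-in via environment variables, so default behavior is unchanged. As a minimal illustration (the variable names and accepted values come from the diff below; how you export them depends on your launch setup), enabling them from Python before starting the server could look like:

import os

# Hypothetical launch-side toggles; the checks in the diff accept "ON", "TRUE", or "1".
os.environ["ENABLE_FLASHINFER_PREFILLED"] = "1"   # new in this commit: flashinfer ragged prefill attention
os.environ["ENABLE_FLASHINFER_DECODE_MLA"] = "1"  # pre-existing flashinfer MLA decode path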

lightllm/models/deepseek2/infer_struct.py

Lines changed: 79 additions & 34 deletions
@@ -3,15 +3,20 @@
 import numpy as np
 import torch.distributed as dist
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
-import flashinfer
 
 
 class Deepseek2InferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
         self.kv_starts = None
+        self.prefill_wrapper = None
+        self.decode_wrapper = None
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -20,12 +25,24 @@ def __init__(self):
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
-        # Only the decode stage using the ppl optimized operator has this management variable
+
         if not self.is_prefill:
             self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
             self.total_token_num_tensor = torch.sum(self.b_seq_len)
             if self.enable_flashinfer_decode_mla:
-                self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                import flashinfer
+                from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
+                )
+                self.kv_lora_rank = model.kv_lora_rank
+                self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
+                self.softmax_scale = model.softmax_scale
+                self.q_data_type = model.data_type
+                self.kv_data_type = model.data_type
+
                 self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device)
                 self.kv_indices = torch.empty(self.batch_size * model.max_seq_length, dtype=torch.int32).to(
                     input_ids.device
@@ -38,38 +55,66 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     self.max_len_in_batch,
                     self.kv_indices,
                 )
-                self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
-                    self.workspace_buffer,
-                    backend="fa2",
-                    use_cuda_graph=True,
-                    qo_indptr=self.q_indptr,
-                    kv_indices=self.kv_indices,
-                    kv_indptr=self.kv_starts,
-                    kv_len_arr=self.b_seq_len,
+                if not self.decode_wrapper:
+                    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(input_ids.device)
+                    self.decode_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
+                        workspace_buffer,
+                        backend="fa2",
+                        use_cuda_graph=True,
+                        qo_indptr=self.q_indptr,
+                        kv_indices=self.kv_indices,
+                        kv_indptr=self.kv_starts,
+                        kv_len_arr=self.b_seq_len,
+                    )
+                self.decode_wrapper.plan(
+                    self.q_indptr,
+                    self.kv_starts,
+                    self.kv_indices,
+                    self.b_seq_len,
+                    self.tp_q_head_num,
+                    self.kv_lora_rank,
+                    self.qk_rope_head_dim,
+                    1,
+                    False,  # causal
+                    self.softmax_scale,
+                    self.q_data_type,
+                    self.kv_data_type,
+                )
+        else:
+            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+            if self.enable_flashinfer_prefilled:
+                import flashinfer
+
+                self.tp_q_head_num = (
+                    model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
                 )
-                self.head_num = model.tp_q_head_num_ * model.world_size_ if self.enable_dp else model.tp_q_head_num_
-                self.kv_lora_rank = model.kv_lora_rank
                 self.qk_rope_head_dim = model.qk_rope_head_dim
+                self.qk_nope_head_dim = model.qk_nope_head_dim
                 self.softmax_scale = model.softmax_scale
                 self.q_data_type = model.data_type
-                self.kv_data_type = model.data_type
-                self.wrapper.plan(
-                    self.q_indptr,
-                    self.kv_starts,
-                    self.kv_indices,
-                    self.b_seq_len,
-                    self.head_num,
-                    self.kv_lora_rank,
-                    self.qk_rope_head_dim,
-                    1,
-                    False,  # causal
-                    self.softmax_scale,
-                    self.q_data_type,
-                    self.kv_data_type,
-                )
 
-        if self.is_prefill:
-            self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
+                q_starts = torch.cat(
+                    [self.b_start_loc, self.b_start_loc[-1:] + (self.b_seq_len - self.b_ready_cache_len)[-1:]], dim=0
+                ).int()
+                kv_starts = torch.cat(
+                    [self.b_kv_start_loc, self.b_kv_start_loc[-1:] + self.b_seq_len[-1:]], dim=0
+                ).int()
+                if not self.prefill_wrapper:
+                    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
+                    self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
+                        workspace_buffer, "NHD"
+                    )
+                self.prefill_wrapper.plan(
+                    qo_indptr=q_starts,
+                    kv_indptr=kv_starts,
+                    num_qo_heads=self.tp_q_head_num,
+                    num_kv_heads=self.tp_q_head_num,
+                    head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                    head_dim_vo=self.qk_nope_head_dim,
+                    q_data_type=self.q_data_type,
+                    causal=True,
+                    sm_scale=self.softmax_scale,
+                )
 
         if self.enable_dp:
             rank = dist.get_rank()
@@ -89,13 +134,13 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
 
     def copy_for_cuda_graph(self, new_infer_state):
         super().copy_for_cuda_graph(new_infer_state)
-        if self.enable_flashinfer_decode_mla:
-            self.wrapper.plan(
+        if self.enable_flashinfer_decode_mla and not self.is_prefill:
+            self.decode_wrapper.plan(
                 self.q_indptr,
                 self.kv_starts,
                 self.kv_indices,
                 self.b_seq_len,
-                self.head_num,
+                self.tp_q_head_num,
                 self.kv_lora_rank,
                 self.qk_rope_head_dim,
                 1,
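For orientation, the prefill wrapper follows flashinfer's plan/run pattern: allocate a workspace buffer once, cache the wrapper on the infer state, re-plan it for each prefill batch from the ragged q/kv index pointers, and call run() inside the attention layer. A standalone sketch of that pattern (not code from this commit; head counts, head dims, and dtypes are illustrative assumptions):

import torch
import flashinfer

# Illustrative sizes: 16 query heads, 192-dim QK (128 nope + 64 rope), 128-dim V output.
num_heads, head_dim_qk, head_dim_vo = 16, 192, 128
q_lens = torch.tensor([5, 7], dtype=torch.int32)   # new query tokens per request
kv_lens = torch.tensor([5, 7], dtype=torch.int32)  # kv tokens per request

# Exclusive prefix sums, analogous to q_starts / kv_starts built in init_some_extra_state.
q_indptr = torch.cat([torch.zeros(1, dtype=torch.int32), q_lens.cumsum(0).int()]).cuda()
kv_indptr = torch.cat([torch.zeros(1, dtype=torch.int32), kv_lens.cumsum(0).int()]).cuda()

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda")
wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")
wrapper.plan(
    qo_indptr=q_indptr,
    kv_indptr=kv_indptr,
    num_qo_heads=num_heads,
    num_kv_heads=num_heads,  # the decompressed MLA K/V carry one head per query head
    head_dim_qk=head_dim_qk,
    head_dim_vo=head_dim_vo,
    q_data_type=torch.float16,
    causal=True,
)

total_q = int(q_lens.sum())
q = torch.randn(total_q, num_heads, head_dim_qk, dtype=torch.float16, device="cuda")
k = torch.randn(total_q, num_heads, head_dim_qk, dtype=torch.float16, device="cuda")
v = torch.randn(total_q, num_heads, head_dim_vo, dtype=torch.float16, device="cuda")
out = wrapper.run(q, k, v)  # shape [total_q, num_heads, head_dim_vo]

Re-planning a cached wrapper keeps the 128 MB workspace allocation out of the per-batch path, matching how the decode wrapper above is now created once and only re-planned.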

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 45 additions & 30 deletions
@@ -69,6 +69,11 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.num_heads = network_config["num_attention_heads"]
         self.num_kv_heads = network_config["num_key_value_heads"]
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
+        self.enable_flashinfer_prefilled = os.getenv("ENABLE_FLASHINFER_PREFILLED", "False").upper() in [
+            "ON",
+            "TRUE",
+            "1",
+        ]
         self.enable_flashinfer_decode_mla = os.getenv("ENABLE_FLASHINFER_DECODE_MLA", "False").upper() in [
             "ON",
             "TRUE",
@@ -220,22 +225,28 @@ def _context_attention_kernel_with_CC(
         out=None,
     ) -> torch.Tensor:
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
-        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
-        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
+        o_tensor = (
+            self.alloc_tensor((q.shape[0], q.shape[1], self.qk_nope_head_dim), dtype=q.dtype) if out is None else out
         )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        else:
+            q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor
 
     def _context_attention_kernel_with_CC_fp8(
@@ -249,20 +260,24 @@ def _context_attention_kernel_with_CC_fp8(
         k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
-        context_attention_fwd_with_v(
-            q_nope,
-            q_rope,
-            k_nope,
-            k_rope,
-            v,
-            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
-            infer_state.b_start_loc,
-            infer_state.b_kv_start_loc,
-            infer_state.b_seq_len,
-            infer_state.b_ready_cache_len,
-            infer_state.max_len_in_batch,
-            self.softmax_scale,
-        )
+        if self.enable_flashinfer_prefilled:
+            k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+            infer_state.prefill_wrapper.run(q, k, v, out=o_tensor)
+        else:
+            context_attention_fwd_with_v(
+                q_nope,
+                q_rope,
+                k_nope,
+                k_rope,
+                v,
+                o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+                infer_state.b_start_loc,
+                infer_state.b_kv_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
         return o_tensor
 
     def _context_attention_kernel_origin(
@@ -378,7 +393,7 @@ def _token_gqa_decode_attention_flashdecoding(
             )
             return o_tensor
         elif self.enable_flashinfer_decode_mla:
-            infer_state.wrapper.run(
+            infer_state.decode_wrapper.run(
                 q_nope,
                 q_rope,
                 kv[:, :, : -self.qk_rope_head_dim],
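The ragged prefill kernel is planned with one KV head per query head at the full head_dim_qk width, so the new branch widens the single shared rope key across heads before calling run(). A small self-contained illustration of that concatenation (dimensions are assumptions matching DeepSeek-V2's 128 nope + 64 rope split; k_rope is assumed to come out of _decompress_kv with a singleton head axis):

import torch

total_tokens, num_heads = 12, 16
nope_dim, rope_dim = 128, 64
k_nope = torch.randn(total_tokens, num_heads, nope_dim)  # per-head decompressed key part
k_rope = torch.randn(total_tokens, 1, rope_dim)          # one rope key shared by all heads
k = torch.cat([k_nope, torch.repeat_interleave(k_rope, num_heads, dim=-2)], dim=-1)
print(k.shape)  # torch.Size([12, 16, 192]) -> (tokens, heads, nope_dim + rope_dim)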

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 1 addition & 2 deletions
@@ -179,8 +179,7 @@ def _fwd_kernel_calcu_index_and_block_seq(
     req_to_token_indexs = torch.randperm(max_input_len, dtype=torch.int32).cuda().view(Z, N_CTX)
     b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX
     b_start_loc = torch.arange(Z).cuda().int() * N_CTX
-    b_start_loc[0] = 0
-    b_req_idx = torch.arange(Z).cuda().int()
+    b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda()
     kv_starts = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)
 
     o = torch.zeros((Z, H, D_HEAD), dtype=dtype, device="cuda")
