
Commit 6ede09e

update mla decode attention. (#696)
1 parent 02effd7 commit 6ede09e

File tree

3 files changed: +75 -39 lines changed


lightllm/common/basemodel/cuda_graph.py

Lines changed: 10 additions & 1 deletion
@@ -1,5 +1,6 @@
 import os
 import torch
+import copy
 from lightllm.utils.log_utils import init_logger
 from lightllm.distributed import custom_comm_ops
 
@@ -27,10 +28,18 @@ def capture_decode(self, decode_func, input_ids, infer_state):
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
         infer_state.total_token_num = self.graph_max_len_in_batch * batch_size
         # warmup
+        # Some parts of the inference code check whether certain attributes already exist
+        # on infer_state and, if not, perform one-time initialization in the first layer so
+        # that later layers can reuse the results, e.g. the logic in
+        # lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py. The warmup call
+        # must therefore receive a shallow copy of infer_state; otherwise infer_state would
+        # already carry these attributes when it is passed into CUDA graph capture, the lazy
+        # init would be skipped, and the tensors attached to infer_state would not be captured.
         for _ in range(1):
             torch.cuda.synchronize()
-            decode_func(input_ids, infer_state)
+            decode_func(input_ids, copy.copy(infer_state))  # infer_state must be shallow-copied here
             torch.cuda.synchronize()
+
         with custom_comm_ops.lightllm_capture_graph():
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 predict_logics = decode_func(input_ids, infer_state)
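
The shallow-copy requirement above boils down to a simple attribute-caching pattern. A minimal, self-contained sketch of the same idea (plain Python, with a hypothetical InferState stand-in rather than lightllm's real class):

    import copy

    class InferState:
        # Hypothetical stand-in for lightllm's infer_state object, for illustration only.
        pass

    def decode_step(state):
        # Lazily initialize an attribute the first time this state object is seen,
        # mirroring the hasattr() guard added in gqa_flash_decoding.py below.
        if not hasattr(state, "cached_meta"):
            state.cached_meta = "allocated on first use"
        return state.cached_meta

    state = InferState()

    # Warmup: pass a shallow copy, so the lazily created attribute lands on the copy
    # and is discarded afterwards; the original stays "fresh".
    decode_step(copy.copy(state))
    assert not hasattr(state, "cached_meta")

    # Capture-time call: the original still lacks the attribute, so the lazy
    # initialization runs again and its tensors get recorded by the CUDA graph.
    decode_step(state)
    assert hasattr(state, "cached_meta")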

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 61 additions & 17 deletions
@@ -1,6 +1,8 @@
 import os
 import torch
 import torch.multiprocessing as mp
+import triton
+import triton.language as tl
 from typing import List
 from lightllm.utils.log_utils import init_logger
 from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
@@ -44,27 +46,19 @@ def gqa_token_decode_attention_flash_decoding(
         out_dtype=torch.bfloat16,
     )
 
+    BLOCK_N = run_config["BLOCK_N"]
+
     from .gqa_flash_decoding_stage1 import flash_decode_stage1
     from .gqa_flash_decoding_stage2 import flash_decode_stage2
 
     o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out
 
-    mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda")
-    mid_o_batch_start_index = alloc_tensor_func(
-        [
-            batch_size,
-        ],
-        dtype=torch.int64,
-        device="cuda",
-    )
-
+    fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda")
     mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda")
     mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda")
 
     vsm_count = flash_decode_stage1(
-        infer_state.total_token_num_tensor,
-        mid_o_block_seq,
-        mid_o_batch_start_index,
+        fake_decode_att_block_seq,
         q_nope.view(calcu_shape1),
         q_rope.view(calcu_shape2),
         kv_nope,
@@ -79,13 +73,40 @@ def gqa_token_decode_attention_flash_decoding(
         **run_config
     )
 
+    if not hasattr(infer_state, "decode_att_block_seq"):
+        assert batch_size <= 2048
+        decode_att_block_seq = torch.empty(
+            [
+                1,
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        mid_o_batch_start_index = torch.empty(
+            [
+                batch_size,
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        _fwd_kernel_calcu_index_and_block_seq[(1,)](
+            infer_state.b_seq_len,
+            decode_att_block_seq,
+            mid_o_batch_start_index,
+            vsm_count,
+            batch_size,
+            BLOCK_N=BLOCK_N,
+            num_warps=4,
+        )
+
+        infer_state.decode_att_block_seq = decode_att_block_seq
+        infer_state.mid_o_batch_start_index = mid_o_batch_start_index
+
     mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda")
     mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda")
 
     flash_decode_stage1(
-        infer_state.total_token_num_tensor,
-        mid_o_block_seq,
-        mid_o_batch_start_index,
+        infer_state.decode_att_block_seq,
         q_nope.view(calcu_shape1),
         q_rope.view(calcu_shape2),
         kv_nope,
@@ -101,12 +122,35 @@ def gqa_token_decode_attention_flash_decoding(
     )
 
     flash_decode_stage2(
-        mid_o_block_seq,
-        mid_o_batch_start_index,
+        infer_state.decode_att_block_seq,
+        infer_state.mid_o_batch_start_index,
         mid_o,
         mid_o_logexpsum,
         infer_state.b_seq_len,
         o_tensor.view(calcu_shape1),
         **run_config
     )
     return o_tensor
+
+
+@triton.jit
+def _fwd_kernel_calcu_index_and_block_seq(
+    b_seq_len_ptr,
+    mid_o_decode_att_block_seq_ptr,
+    mid_o_batch_start_index_ptr,
+    num_sm,
+    batch_size,
+    BLOCK_N: tl.constexpr,
+):
+    b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0)
+    total_token_num = tl.sum(b_seq_len)
+
+    block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1
+    block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N
+
+    block_seq_len = tl.cdiv(b_seq_len, block_seq)
+    cumsum_seq_len = tl.cumsum(block_seq_len)
+    batch_start_index = cumsum_seq_len - block_seq_len
+    tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size)
+    tl.store(mid_o_decode_att_block_seq_ptr, block_seq)
+    return
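
For readers not fluent in Triton, the new _fwd_kernel_calcu_index_and_block_seq can be read as the following host-side reference. This is an approximation written for clarity (integer floor division instead of the kernel's float-divide-and-cast), not code from the commit:

    import torch

    def calcu_index_and_block_seq_reference(b_seq_len: torch.Tensor, num_sm: int, block_n: int):
        total_token_num = int(b_seq_len.sum())
        # Aim for roughly num_sm * 4 blocks over all tokens, rounded up to a multiple of BLOCK_N.
        block_seq = total_token_num // (num_sm * 4) + 1
        block_seq = ((block_seq + block_n - 1) // block_n) * block_n
        # Each request contributes ceil(seq_len / block_seq) blocks; its partial results start
        # where the previous request's blocks end (an exclusive prefix sum).
        block_seq_len = (b_seq_len + block_seq - 1) // block_seq
        batch_start_index = torch.cumsum(block_seq_len, dim=0) - block_seq_len
        return block_seq, batch_start_index

    # Example with made-up sizes: 3 requests, 132 SMs, BLOCK_N = 16.
    b_seq_len = torch.tensor([4096, 1024, 300])
    block_seq, starts = calcu_index_and_block_seq_reference(b_seq_len, num_sm=132, block_n=16)
    # total = 5420 tokens -> 5420 // (132 * 4) + 1 = 11 -> rounded up to block_seq = 16
    # blocks per request: [256, 64, 19] -> batch_start_index = [0, 256, 320]

Because the result depends only on b_seq_len and the SM count, it is computed once, cached on infer_state, and reused on every CUDA graph replay.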

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py

Lines changed: 4 additions & 21 deletions
@@ -35,9 +35,7 @@ def _fwd_kernel_flash_decode_stage1_padding(
     stride_mid_od,
     stride_mid_o_eh,
     stride_mid_o_es,
-    total_token_ptr,
     block_size_ptr,
-    batch_start_index_ptr,
     num_sm,
     head_group_num,
     head_num,
@@ -51,15 +49,9 @@ def _fwd_kernel_flash_decode_stage1_padding(
 ):
     # cur_kv_head = 0
     sm_id = tl.program_id(0).to(tl.int64)
-    grid_id = sm_id
     out_batch_start_index = tl.cast(0, tl.int64)
-    total_token_num = tl.load(total_token_ptr, eviction_policy="evict_last")
+    block_seq = tl.load(block_size_ptr, eviction_policy="evict_last")
 
-    block_seq = tl.cast(total_token_num / num_sm / 4, dtype=tl.int32) + 1
-    block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N
-
-    if grid_id == 0:
-        tl.store(block_size_ptr, block_seq)
     cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
@@ -163,19 +155,14 @@ def _fwd_kernel_flash_decode_stage1_padding(
             )
             sm_id += num_sm
 
-        if grid_id == 0:
-            tl.store(batch_start_index_ptr + cur_batch, out_batch_start_index)
-
         out_batch_start_index += cur_block_num // head_group_num
         sm_id -= cur_block_num
     return
 
 
 @torch.no_grad()
 def flash_decode_stage1(
-    total_token_num_tensor: torch.Tensor,
-    out_block_seq: torch.Tensor,
-    batch_start_index: torch.Tensor,
+    in_block_seq: torch.Tensor,
     q_nope,
     q_rope,
     kv_nope,
@@ -227,9 +214,7 @@ def flash_decode_stage1(
         *kv_rope.stride(),
         *mid_out.stride(),
         *mid_out_logsumexp.stride(),
-        total_token_num_tensor,
-        out_block_seq,
-        batch_start_index,
+        in_block_seq,
         num_sm=1,
         head_group_num=head_group_num,
         head_num=q_head_num,
@@ -271,9 +256,7 @@ def flash_decode_stage1(
         *kv_rope.stride(),
        *mid_out.stride(),
        *mid_out_logsumexp.stride(),
-        total_token_num_tensor,
-        out_block_seq,
-        batch_start_index,
+        in_block_seq,
        num_sm=num_sm,
        head_group_num=head_group_num,
        head_num=q_head_num,
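
A closing note on sizing that ties the three files together: stage 1 now reads the precomputed block_seq instead of deriving it in-kernel, and because block_seq is rounded up from total_token_num / (num_sm * 4), the total number of work blocks across the batch can never exceed num_sm * 4 + batch_size. That appears to be the arithmetic behind the vsm_count * 4 + batch_size slots gqa_flash_decoding.py allocates for mid_o and mid_o_logexpsum. A quick self-contained check of the bound (made-up sizes, not part of the commit):

    import math
    import random

    def num_blocks(b_seq_len, num_sm, block_n):
        # Mirror the block_seq computation above, then count blocks per request.
        total = sum(b_seq_len)
        block_seq = total // (num_sm * 4) + 1
        block_seq = math.ceil(block_seq / block_n) * block_n
        return sum(math.ceil(s / block_seq) for s in b_seq_len)

    random.seed(0)
    for _ in range(1000):
        batch_size = random.randint(1, 128)
        b_seq_len = [random.randint(1, 16384) for _ in range(batch_size)]
        num_sm = random.choice([80, 108, 132])
        # ceil(s / block_seq) <= s / block_seq + 1, so the sum is bounded by
        # total / block_seq + batch_size <= num_sm * 4 + batch_size.
        assert num_blocks(b_seq_len, num_sm, block_n=16) <= num_sm * 4 + batch_size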
