
Commit f64ce3b

tune gqa for deepseek v2

1 parent e4ec0c0

File tree

2 files changed (+17, -26 lines)


lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@ def gqa_token_decode_attention_flash_decoding(
     out=None,
     alloc_tensor_func=torch.empty,
 ):
-    BLOCK_SEQ = 128
+    BLOCK_SEQ = 64
     batch_size = infer_state.batch_size
     max_len_in_batch = infer_state.max_len_in_batch
     calcu_shape1 = (batch_size, q_head_num, kv_lora_rank)
@@ -27,10 +27,10 @@ def gqa_token_decode_attention_flash_decoding(
     o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out
 
     mid_o = alloc_tensor_func(
-        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank], dtype=torch.float32, device="cuda"
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank], dtype=q_nope.dtype, device="cuda"
     )
     mid_o_logexpsum = alloc_tensor_func(
-        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=q_nope.dtype, device="cuda"
     )
 
     flash_decode_stage1(
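For context, mid_o and mid_o_logexpsum are the intermediates of a two-stage flash-decoding scheme: stage 1 writes one softmax-normalized partial output plus a log-sum-exp per sequence block of BLOCK_SEQ tokens, and stage 2 merges the blocks. The sketch below is a simplified reference of that merge in plain PyTorch, not lightllm's stage-2 kernel, and it assumes every request fills all of its sequence blocks; it only illustrates what these tensors hold, why a smaller BLOCK_SEQ produces more blocks to merge, and why allocating them in q_nope.dtype instead of float32 shrinks the stage-1 to stage-2 handoff at some cost in merge precision.

```python
import torch

def merge_seq_blocks(mid_o: torch.Tensor, mid_o_logexpsum: torch.Tensor) -> torch.Tensor:
    """Reference merge of per-block partial attention outputs (hypothetical helper).

    mid_o:            [batch, q_head_num, num_seq_blocks, kv_lora_rank]
                      per-block normalized partial outputs (acc / sum_exp)
    mid_o_logexpsum:  [batch, q_head_num, num_seq_blocks]
                      per-block max_logic + log(sum_exp)
    """
    # The relative weight of each sequence block is a softmax over its
    # log-sum-exp, which merges the per-block softmaxes in a numerically
    # stable way.
    weights = torch.softmax(mid_o_logexpsum.float(), dim=-1)            # [b, h, s]
    merged = (mid_o.float() * weights.unsqueeze(-1)).sum(dim=2)         # [b, h, d]
    return merged.to(mid_o.dtype)
```

With BLOCK_SEQ lowered from 128 to 64, each request is split into roughly twice as many sequence blocks, so more, smaller partial results are produced in stage 1 and merged afterwards.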

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py

Lines changed: 14 additions & 23 deletions

@@ -36,19 +36,18 @@ def _fwd_kernel_flash_decode_stage1(
     stride_mid_o_eb,
     stride_mid_o_eh,
     stride_mid_o_es,
-    gqa_group_size,
     Q_HEAD_NUM: tl.constexpr,
     BLOCK_SEQ: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_ROPE_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
-    cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
-    seq_start_block = tl.program_id(2)
+    seq_start_block = tl.program_id(0)
+    cur_q_head = tl.program_id(1)
+    cur_batch = tl.program_id(2)
 
     cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
-    cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs
+    cur_q_head_range = cur_q_head * Q_HEAD_NUM + cur_q_head_offs
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL)
@@ -59,7 +58,8 @@ def _fwd_kernel_flash_decode_stage1(
 
     off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :]
     off_rope_q = cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :]
-
+    q = tl.load(Q_nope + off_q)
+    q_rope = tl.load(Q_rope + off_rope_q)
     block_n_size = (
         tl.where(
             cur_batch_end_index - cur_batch_start_index <= 0,
@@ -70,27 +70,20 @@ def _fwd_kernel_flash_decode_stage1(
     )
 
     offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
-
-    q = tl.load(Q_nope + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0)
-    q_rope = tl.load(
-        Q_rope + off_rope_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0
-    )
-
     sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
     max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf")
     acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)
-
     for start_n in range(0, block_n_size, 1):
         offs_n_new = start_n * BLOCK_N + offs_n
         kv_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
             mask=offs_n_new < cur_batch_end_index,
             other=0,
         )
-        off_kv = kv_loc[None, :] * stride_kv_bs + cur_kv_head * stride_kv_h + offs_d[:, None]
+        off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
         kv = tl.load(KV_nope + off_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
         att_value = tl.dot(q, kv)
-        off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + cur_kv_head * stride_kv_rope_h + offs_rope_d[:, None]
+        off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None]
         rope_kv = tl.load(KV_rope + off_rope_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
         att_value += tl.dot(q_rope, rope_kv)
 
@@ -120,12 +113,10 @@ def _fwd_kernel_flash_decode_stage1(
     tl.store(
         Mid_O + off_mid_o,
         acc / sum_exp[:, None],
-        mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size,
     )
     tl.store(
         Mid_O_LogExpSum + off_mid_o_logexpsum,
         max_logic + tl.log(sum_exp),
-        mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size,
     )
     return
 
@@ -147,6 +138,7 @@ def flash_decode_stage1(
 ):
     BLOCK_SEQ = block_seq
     BLOCK_N = 16
+    BLOCK_Q_HEAD = 16
     assert BLOCK_SEQ % BLOCK_N == 0
     # shape constraints
     q_nope_dim = q_nope.shape[-1]
@@ -158,9 +150,9 @@ def flash_decode_stage1(
     assert q_rope_dim in {16, 32, 64, 128, 256}
 
     sm_scale = softmax_scale  # compute the scale coefficient
-    batch, kv_head_num = B_req_idx.shape[0], kv_nope.shape[1]
-    grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
-    gqa_group_size = q_nope.shape[1] // kv_nope.shape[1]
+    batch, q_head_num = B_req_idx.shape[0], q_nope.shape[1]
+    assert q_head_num % BLOCK_Q_HEAD == 0
+    grid = (triton.cdiv(max_len_in_batch, BLOCK_SEQ), q_head_num // BLOCK_Q_HEAD, batch)
 
     _fwd_kernel_flash_decode_stage1[grid](
         q_nope,
@@ -194,13 +186,12 @@ def flash_decode_stage1(
         mid_out_logsumexp.stride(0),
         mid_out_logsumexp.stride(1),
         mid_out_logsumexp.stride(2),
-        gqa_group_size,
-        Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
+        Q_HEAD_NUM=q_head_num,
         BLOCK_SEQ=BLOCK_SEQ,
         BLOCK_DMODEL=q_nope_dim,
         BLOCK_ROPE_DMODEL=q_rope_dim,
         BLOCK_N=BLOCK_N,
-        num_warps=2,
+        num_warps=4,
         num_stages=2,
     )
     return
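For reference, below is a minimal sketch (hypothetical helpers, not lightllm code) of how the launch grid changes with this commit. It assumes DeepSeek-V2's MLA layout, where the compressed KV cache has a single head (kv_nope.shape[1] == 1): every query head reads the same KV entries, so the cur_kv_head offsets, the gqa_group_size argument, and the per-group load/store masks become unnecessary, and the kernel simply tiles query heads in blocks of BLOCK_Q_HEAD.

```python
def cdiv(a: int, b: int) -> int:
    # ceiling division, same as triton.cdiv
    return (a + b - 1) // b

def old_grid(batch: int, kv_head_num: int, max_len_in_batch: int, block_seq: int = 128):
    # before: program_id(0) = batch, (1) = kv head, (2) = sequence block
    return (batch, kv_head_num, cdiv(max_len_in_batch, block_seq))

def new_grid(batch: int, q_head_num: int, max_len_in_batch: int,
             block_seq: int = 64, block_q_head: int = 16):
    # after: program_id(0) = sequence block, (1) = block of query heads, (2) = batch
    assert q_head_num % block_q_head == 0
    return (cdiv(max_len_in_batch, block_seq), q_head_num // block_q_head, batch)
```

Each kernel instance now covers one sequence block of one block of query heads for one request; the num_warps and BLOCK_SEQ values above are simply the new settings recorded in the commit.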
