Skip to content

Commit a9c2428

Browse files
committed
refactor: move tuning scripts to a new dir
build: add tuning for llama vsm; refactor: move deepseek tuning to new dir
1 parent 5a3edef commit a9c2428

File tree

5 files changed

+412
-90
lines changed

5 files changed

+412
-90
lines changed

lightllm/models/llama/triton_kernel/gqa_flash_decoding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def gqa_token_decode_attention_flash_decoding(
1212
from .gqa_flash_decoding_stage1 import flash_decode_stage1
1313
from .gqa_flash_decoding_stage2 import flash_decode_stage2
1414

15-
o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
15+
o_tensor = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device) if out is None else out
1616

1717
mid_o = alloc_tensor_func(
1818
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"

lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ def try_to_get_best_config(
4444
return config
4545
else:
4646
config = {
47-
"BLOCK_N": 16,
47+
"BLOCK_N": 64,
4848
"BLOCK_Q_HEAD": 16,
4949
"stage1_num_warps": 4,
5050
"stage1_num_stages": 2,
5151
"stage2_num_warps": 4,
52-
"stage2_num_stages": 2,
52+
"stage2_num_stages": 1,
5353
}
5454
return config
5555

@@ -150,38 +150,45 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
150150
mid_o_logexpsum: [q_head_num, total_seq_block_num]
151151
"""
152152
sm_id = tl.program_id(0).to(tl.int64)
153-
block_size = tl.load(block_size, eviction_policy="evict_last")
153+
block_size = tl.load(block_size)
154154

155155
out_batch_start_index = tl.cast(0, tl.int64)
156156
q_head_off = tl.arange(0, Q_HEAD_NUM)
157157
d_off = tl.arange(0, BLOCK_DMODEL)
158158

159-
for cur_batch in tl.range(0, batch_size, 1):
160-
cur_req_idx = tl.load(b_req_idx + cur_batch, eviction_policy="evict_last")
161-
cur_seq_len = tl.load(b_seq_len + cur_batch, eviction_policy="evict_last")
159+
for cur_batch in range(0, batch_size):
160+
cur_req_idx = tl.load(b_req_idx + cur_batch)
161+
cur_seq_len = tl.load(b_seq_len + cur_batch)
162162

163163
cur_num_of_blocks = tl.cdiv(cur_seq_len, block_size)
164164
cur_num_of_kv_head_pairs = cur_num_of_blocks * kv_head_num
165165

166-
loop_sm_id = sm_id
167-
while loop_sm_id < cur_num_of_kv_head_pairs:
168-
cur_block_idx = loop_sm_id // kv_head_num
169-
cur_kv_head_idx = loop_sm_id % kv_head_num
166+
# loop_sm_id = sm_id
167+
while sm_id < cur_num_of_kv_head_pairs:
168+
cur_block_idx = sm_id % cur_num_of_blocks
169+
cur_kv_head_idx = sm_id // cur_num_of_blocks
170+
# cur_block_idx = sm_id // kv_head_num
171+
# cur_kv_head_idx = sm_id % kv_head_num
170172

171-
cur_q_start = cur_kv_head_idx * gqa_group_size
172-
cur_q_range = cur_q_start + q_head_off
173+
cur_q_range = cur_kv_head_idx * gqa_group_size + q_head_off
173174
cur_q_mask = q_head_off < gqa_group_size
174-
q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
175-
q_tensor = tl.load(q + q_off, mask=cur_q_mask[:, None], other=0.0) # shape: [Q_HEAD_NUM, BLOCK_DMODEL]
176175

177176
cur_kv_start = cur_block_idx * block_size
178-
cur_kv_end = tl.minimum(cur_kv_start + block_size, cur_seq_len)
177+
178+
q_off = cur_batch * stride_q_bs + cur_q_range[:, None] * stride_q_h + d_off[None, :]
179+
q_tensor = tl.load(
180+
q + q_off,
181+
mask=cur_q_mask[:, None],
182+
other=0.0,
183+
) # shape: [Q_HEAD_NUM, BLOCK_DMODEL]
179184

180185
sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
181186
max_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf")
182187
accumu = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32)
183188

184-
for chunk_idx in tl.range(0, tl.cdiv(cur_kv_end - cur_kv_start, BLOCK_N), 1, num_stages=NUM_STAGES):
189+
cur_total_chunk = tl.cdiv(tl.minimum(cur_kv_start + block_size, cur_seq_len) - cur_kv_start, BLOCK_N)
190+
191+
for chunk_idx in tl.range(0, cur_total_chunk, 1, num_stages=NUM_STAGES):
185192
cur_chunk_start = cur_kv_start + chunk_idx * BLOCK_N
186193
cur_chunk_range = cur_chunk_start + tl.arange(0, BLOCK_N)
187194
cur_chunk_mask = cur_chunk_range < cur_seq_len
@@ -196,10 +203,10 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
196203
k_off = (
197204
cur_kv_loc[None, :] * stride_k_bs + cur_kv_head_idx * stride_k_h + d_off[:, None]
198205
) # shape: [BLOCK_DMODEL, BLOCK_N]
206+
v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
199207
k_tensor = tl.load(k + k_off, mask=cur_chunk_mask[None, :], other=0.0)
208+
200209
att_tensor = tl.dot(q_tensor, k_tensor) # shape: [Q_HEAD_NUM, BLOCK_N]
201-
v_off = cur_kv_loc[:, None] * stride_v_bs + cur_kv_head_idx * stride_v_h + d_off[None, :]
202-
v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0) # shape: [BLOCK_N, BLOCK_DMODEL]
203210
att_tensor *= softmax_scale
204211
att_tensor = tl.where(cur_chunk_mask[None, :], att_tensor, float("-inf"))
205212

@@ -209,7 +216,8 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
209216
exp_logic = tl.exp(att_tensor - new_max[:, None])
210217
log_scale = tl.exp(max_exp - new_max)
211218
accumu *= log_scale[:, None]
212-
accumu += tl.dot(exp_logic, v_tensor.to(accumu.dtype))
219+
v_tensor = tl.load(v + v_off, mask=cur_chunk_mask[:, None], other=0.0) # shape: [BLOCK_N, BLOCK_DMODEL]
220+
accumu += tl.dot(exp_logic.to(v_tensor.dtype), v_tensor)
213221

214222
sum_exp = sum_exp * log_scale + tl.sum(exp_logic, axis=1)
215223
max_exp = new_max
@@ -223,12 +231,14 @@ def _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1(
223231
cur_q_range * stride_mid_o_logexpsum_h
224232
+ (out_batch_start_index + cur_block_idx) * stride_mid_o_logexpsum_seq
225233
)
234+
max_exp = max_exp + tl.log(sum_exp)
226235
tl.store(
227236
mid_o_logexpsum + off_mid_o_logexpsum,
228-
max_exp + tl.log(sum_exp),
237+
max_exp,
229238
mask=cur_q_mask,
230239
)
231-
loop_sm_id += num_sm
240+
sm_id += num_sm
241+
sm_id -= cur_num_of_kv_head_pairs
232242
out_batch_start_index += cur_num_of_blocks
233243

234244

@@ -276,7 +286,7 @@ def gqa_token_decode_attention_flash_decoding_vsm_stage1(
276286
*mid_o.stride(),
277287
*mid_o_logexpsum.stride(),
278288
BLOCK_N=run_config["BLOCK_N"],
279-
Q_HEAD_NUM=max(run_config["BLOCK_Q_HEAD"], triton.next_power_of_2(q_head_num)),
289+
Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
280290
BLOCK_DMODEL=q.shape[-1],
281291
NUM_STAGES=run_config["stage1_num_stages"],
282292
num_stages=run_config["stage1_num_stages"],
@@ -424,7 +434,7 @@ def gqa_token_decode_attention_flash_decoding_vsm(
424434
out_dtype=q.dtype,
425435
)
426436

427-
if not out:
437+
if out is None:
428438
out = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device)
429439

430440
num_vsm = emstimate_stage1_vsm(

test/kernel/alignment/llama_gqa_decode_vsm.py

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import random
33
import torch
4+
from tqdm import tqdm
45
from lightllm.common.basemodel.infer_struct import InferStateInfo
56
from lightllm.common.req_manager import ReqManager
67
from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
@@ -20,81 +21,83 @@ def test_vsm_gqa_decoding_align(self):
2021
torch.backends.cudnn.deterministic = True
2122
torch.backends.cudnn.benchmark = False
2223

23-
bs_list = range(1, 40)
24+
bs_list = [1, 8, 16, 32, 64, 128, 256]
2425
group_size_list = [16, 32, 64]
25-
seq_len_list = [128, 512, 1024, 2048]
26+
seq_len_list = [128, 512, 1024, 2048, 4096, 8192]
2627
q_head_dim_list = [64, 128]
27-
q_head_num_list = [16, 32, 64]
28+
q_head_num_list = [8, 16, 32]
2829

29-
for bs in bs_list:
30-
for group_size in group_size_list:
31-
for seq_len_m in seq_len_list:
32-
for q_head_dim in q_head_dim_list:
33-
for q_head_num in q_head_num_list:
34-
if q_head_num < group_size:
35-
continue
36-
kv_head_num = q_head_num // group_size
37-
q_head_dim = q_head_dim
38-
kv_head_dim = q_head_dim
39-
seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
40-
total_token_in_the_batch = seq_len.sum().item()
41-
rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
30+
def get_test_configs():
31+
for bs in bs_list:
32+
for group_size in group_size_list:
33+
for seq_len_m in seq_len_list:
34+
for q_head_dim in q_head_dim_list:
35+
for q_head_num in q_head_num_list:
36+
if q_head_num < group_size:
37+
continue
38+
yield bs, group_size, seq_len_m, q_head_dim, q_head_num
4239

43-
q_shape = [bs, q_head_num, q_head_dim]
44-
kv_shape = [
45-
rounded_total_token_in_the_batch,
46-
kv_head_num,
47-
kv_head_dim,
48-
]
49-
qkv_dtype = torch.float16
40+
for bs, group_size, seq_len_m, q_head_dim, q_head_num in tqdm(list(get_test_configs())):
41+
kv_head_num = q_head_num // group_size
42+
q_head_dim = q_head_dim
43+
kv_head_dim = q_head_dim
44+
seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
45+
total_token_in_the_batch = seq_len.sum().item()
46+
rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
5047

51-
q, k, v = (
52-
torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
53-
torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
54-
torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
55-
)
56-
q, k, v = q / 10, k / 10, v / 10
48+
q_shape = [bs, q_head_num, q_head_dim]
49+
kv_shape = [
50+
rounded_total_token_in_the_batch,
51+
kv_head_num,
52+
kv_head_dim,
53+
]
54+
qkv_dtype = torch.float16
5755

58-
req_to_token_index = torch.zeros((bs, 2048)) - 1
59-
token_index = torch.arange(rounded_total_token_in_the_batch)
56+
q, k, v = (
57+
torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
58+
torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
59+
torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
60+
)
61+
q, k, v = q / 10, k / 10, v / 10
6062

61-
total_count = 0
62-
for i in range(bs):
63-
req_to_token_index[i, : seq_len[i]] = token_index[
64-
total_count : total_count + seq_len[i]
65-
]
66-
total_count += seq_len[i]
63+
req_to_token_index = torch.zeros((bs, seq_len_m)) - 1
64+
token_index = torch.arange(rounded_total_token_in_the_batch)
6765

68-
req_to_token_index = req_to_token_index.long().cuda()
66+
total_count = 0
67+
for i in range(bs):
68+
req_to_token_index[i, : seq_len[i]] = token_index[total_count : total_count + seq_len[i]]
69+
total_count += seq_len[i]
6970

70-
b_req_idx = torch.arange(bs, device="cuda")
71-
infer_state = InferStateInfo()
72-
infer_state.req_manager = ReqManager(bs, 2048, None)
73-
infer_state.req_manager.req_to_token_indexs = req_to_token_index
74-
infer_state.b_req_idx = b_req_idx.cuda()
75-
infer_state.b_seq_len = seq_len.cuda()
76-
infer_state.max_len_in_batch = 2048
77-
infer_state.batch_size = bs
78-
infer_state.q_head_num = q_head_num
79-
infer_state.q_head_dim = q_head_dim
80-
infer_state.kv_head_num = kv_head_num
81-
infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
82-
infer_state.total_token_num = torch.tensor(
83-
[total_token_in_the_batch], dtype=torch.int32
84-
).cuda()
85-
new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
86-
old_out = gqa_token_decode_attention_flash_decoding(
87-
q,
88-
infer_state,
89-
infer_state.q_head_num,
90-
infer_state.q_head_dim,
91-
k,
92-
v,
93-
)
94-
cos_sim = (
95-
torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
96-
)
97-
self.assertGreaterEqual(cos_sim, 0.99)
71+
req_to_token_index = req_to_token_index.long().cuda()
72+
73+
b_req_idx = torch.arange(bs, device="cuda")
74+
infer_state = InferStateInfo()
75+
infer_state.req_manager = ReqManager(bs, 2048, None)
76+
infer_state.req_manager.req_to_token_indexs = req_to_token_index
77+
infer_state.b_req_idx = b_req_idx.cuda()
78+
infer_state.b_seq_len = seq_len.cuda()
79+
infer_state.max_len_in_batch = seq_len_m
80+
infer_state.batch_size = bs
81+
infer_state.q_head_num = q_head_num
82+
infer_state.q_head_dim = q_head_dim
83+
infer_state.kv_head_num = kv_head_num
84+
infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
85+
infer_state.total_token_num = torch.tensor([total_token_in_the_batch], dtype=torch.int32).cuda()
86+
new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
87+
old_out = gqa_token_decode_attention_flash_decoding(
88+
q,
89+
infer_state,
90+
infer_state.q_head_num,
91+
infer_state.q_head_dim,
92+
k,
93+
v,
94+
)
95+
cos_sim = torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
96+
self.assertGreaterEqual(
97+
cos_sim,
98+
0.9,
99+
f"bs={bs}, group_size={group_size}, seq_len={seq_len_m}, q_head_dim={q_head_dim}, q_head_num={q_head_num}",
100+
)
98101

99102

100103
if __name__ == "__main__":
File renamed without changes.

0 commit comments

Comments
 (0)