Commit 1b9f4f4

build: add tuning for llama vsm
1 parent 18ea040

4 files changed: +370 -68 lines changed

lightllm/models/llama/triton_kernel/gqa_flash_decoding.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ def gqa_token_decode_attention_flash_decoding(
     from .gqa_flash_decoding_stage1 import flash_decode_stage1
     from .gqa_flash_decoding_stage2 import flash_decode_stage2
 
-    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
+    o_tensor = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device) if out is None else out
 
     mid_o = alloc_tensor_func(
         [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
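The only change here is passing dtype and device to alloc_tensor_func as keyword arguments rather than positionally. A minimal sketch of why keyword arguments are more robust, assuming a hypothetical allocator whose signature has an extra parameter between shape and dtype (names below are illustrative, not lightllm's actual API):

import torch

# Hypothetical allocator (illustrative only): an extra positional parameter
# sits between `shape` and `dtype`.
def alloc_tensor_func(shape, is_graph_out=False, dtype=torch.float32, device="cpu"):
    return torch.empty(shape, dtype=dtype, device=device)

q = torch.randn(4, 32, 128)

# Positional call: q.dtype would bind to `is_graph_out` and q.device to `dtype`.
# o_bad = alloc_tensor_func(q.shape, q.dtype, q.device)

# Keyword call, as in the patch: unaffected by changes to the allocator's signature.
o_tensor = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device)
print(o_tensor.shape, o_tensor.dtype, o_tensor.device)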

lightllm/models/llama/triton_kernel/gqa_flash_decoding_vsm.py

Lines changed: 1 addition & 1 deletion

@@ -286,7 +286,7 @@ def gqa_token_decode_attention_flash_decoding_vsm_stage1(
         *mid_o.stride(),
         *mid_o_logexpsum.stride(),
         BLOCK_N=run_config["BLOCK_N"],
-        Q_HEAD_NUM=triton.next_power_of_2(gqa_group_size),
+        Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
         BLOCK_DMODEL=q.shape[-1],
         NUM_STAGES=run_config["stage1_num_stages"],
         num_stages=run_config["stage1_num_stages"],
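The stage-1 kernel's Q_HEAD_NUM constant is now clamped to a minimum of 16. A plausible reason (an assumption on my part, not stated in the commit) is that Triton block operations such as tl.dot generally require tile dimensions of at least 16, so small GQA group sizes have to be padded up. A quick sketch of the constant before and after the change, for a range of group sizes:

import triton

# Q_HEAD_NUM before vs. after the patch, for typical GQA group sizes.
for gqa_group_size in (1, 2, 4, 8, 16, 32, 64):
    before = triton.next_power_of_2(gqa_group_size)
    after = max(16, triton.next_power_of_2(gqa_group_size))
    print(f"group_size={gqa_group_size:>2}  before={before:>2}  after={after:>2}")

For group sizes of 16 and above the two expressions agree; the clamp only affects the smaller groups.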

test/kernel/alignment/llama_gqa_decode_vsm.py

Lines changed: 70 additions & 66 deletions

@@ -1,6 +1,7 @@
 import unittest
 import random
 import torch
+from tqdm import tqdm
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.req_manager import ReqManager
 from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
@@ -20,81 +21,84 @@ def test_vsm_gqa_decoding_align(self):
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
 
-        bs_list = range(1, 40)
+        bs_list = [1, 8, 16, 32, 64, 128, 256]
         group_size_list = [16, 32, 64]
-        seq_len_list = [128, 512, 1024, 2048]
+        seq_len_list = [128, 512, 1024, 2048, 4096, 8192]
         q_head_dim_list = [64, 128]
-        q_head_num_list = [16, 32, 64]
+        q_head_num_list = [8, 16, 32]
 
-        for bs in bs_list:
-            for group_size in group_size_list:
-                for seq_len_m in seq_len_list:
-                    for q_head_dim in q_head_dim_list:
-                        for q_head_num in q_head_num_list:
-                            if q_head_num < group_size:
-                                continue
-                            kv_head_num = q_head_num // group_size
-                            q_head_dim = q_head_dim
-                            kv_head_dim = q_head_dim
-                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
-                            total_token_in_the_batch = seq_len.sum().item()
-                            rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
+        def get_test_configs():
+            for bs in bs_list:
+                for group_size in group_size_list:
+                    for seq_len_m in seq_len_list:
+                        for q_head_dim in q_head_dim_list:
+                            for q_head_num in q_head_num_list:
+                                if q_head_num < group_size:
+                                    continue
+                                yield bs, group_size, seq_len_m, q_head_dim, q_head_num
+
+        for bs, group_size, seq_len_m, q_head_dim, q_head_num in tqdm(list(get_test_configs())):
+            kv_head_num = q_head_num // group_size
+            q_head_dim = q_head_dim
+            kv_head_dim = q_head_dim
+            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
+            total_token_in_the_batch = seq_len.sum().item()
+            rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
 
-                            q_shape = [bs, q_head_num, q_head_dim]
-                            kv_shape = [
-                                rounded_total_token_in_the_batch,
-                                kv_head_num,
-                                kv_head_dim,
-                            ]
-                            qkv_dtype = torch.float16
+            q_shape = [bs, q_head_num, q_head_dim]
+            kv_shape = [
+                rounded_total_token_in_the_batch,
+                kv_head_num,
+                kv_head_dim,
+            ]
+            qkv_dtype = torch.float16
 
-                            q, k, v = (
-                                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
-                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
-                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
-                            )
-                            q, k, v = q / 10, k / 10, v / 10
+            q, k, v = (
+                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
+                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+            )
+            q, k, v = q / 10, k / 10, v / 10
 
-                            req_to_token_index = torch.zeros((bs, 2048)) - 1
-                            token_index = torch.arange(rounded_total_token_in_the_batch)
+            req_to_token_index = torch.zeros((bs, seq_len_m)) - 1
+            token_index = torch.arange(rounded_total_token_in_the_batch)
 
-                            total_count = 0
-                            for i in range(bs):
-                                req_to_token_index[i, : seq_len[i]] = token_index[
-                                    total_count : total_count + seq_len[i]
-                                ]
-                                total_count += seq_len[i]
+            total_count = 0
+            for i in range(bs):
+                req_to_token_index[i, : seq_len[i]] = token_index[
+                    total_count : total_count + seq_len[i]
+                ]
+                total_count += seq_len[i]
 
-                            req_to_token_index = req_to_token_index.long().cuda()
+            req_to_token_index = req_to_token_index.long().cuda()
 
-                            b_req_idx = torch.arange(bs, device="cuda")
-                            infer_state = InferStateInfo()
-                            infer_state.req_manager = ReqManager(bs, 2048, None)
-                            infer_state.req_manager.req_to_token_indexs = req_to_token_index
-                            infer_state.b_req_idx = b_req_idx.cuda()
-                            infer_state.b_seq_len = seq_len.cuda()
-                            infer_state.max_len_in_batch = 2048
-                            infer_state.batch_size = bs
-                            infer_state.q_head_num = q_head_num
-                            infer_state.q_head_dim = q_head_dim
-                            infer_state.kv_head_num = kv_head_num
-                            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
-                            infer_state.total_token_num = torch.tensor(
-                                [total_token_in_the_batch], dtype=torch.int32
-                            ).cuda()
-                            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
-                            old_out = gqa_token_decode_attention_flash_decoding(
-                                q,
-                                infer_state,
-                                infer_state.q_head_num,
-                                infer_state.q_head_dim,
-                                k,
-                                v,
-                            )
-                            cos_sim = (
-                                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
-                            )
-                            self.assertGreaterEqual(cos_sim, 0.99)
+            b_req_idx = torch.arange(bs, device="cuda")
+            infer_state = InferStateInfo()
+            infer_state.req_manager = ReqManager(bs, 2048, None)
+            infer_state.req_manager.req_to_token_indexs = req_to_token_index
+            infer_state.b_req_idx = b_req_idx.cuda()
+            infer_state.b_seq_len = seq_len.cuda()
+            infer_state.max_len_in_batch = seq_len_m
+            infer_state.batch_size = bs
+            infer_state.q_head_num = q_head_num
+            infer_state.q_head_dim = q_head_dim
+            infer_state.kv_head_num = kv_head_num
+            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
+            infer_state.total_token_num = torch.tensor(
+                [total_token_in_the_batch], dtype=torch.int32
+            ).cuda()
+            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
+            old_out = gqa_token_decode_attention_flash_decoding(
+                q,
+                infer_state,
+                infer_state.q_head_num,
+                infer_state.q_head_dim,
+                k,
+                v,
+            )
+            cos_sim = (
+                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
+            )
+            self.assertGreaterEqual(cos_sim, 0.9, f'bs={bs}, group_size={group_size}, seq_len={seq_len_m}, q_head_dim={q_head_dim}, q_head_num={q_head_num}')
 
 
 if __name__ == "__main__":
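The test refactor flattens the five nested loops into a generator so the whole configuration sweep can be wrapped in tqdm for progress reporting, loosens the pass threshold from 0.99 to 0.9, and attaches the failing configuration to the assertion message. A rough standalone sketch of the same config-sweep pattern using itertools.product (illustrative only, not the committed code):

import itertools
from tqdm import tqdm

bs_list = [1, 8, 16, 32, 64, 128, 256]
group_size_list = [16, 32, 64]
seq_len_list = [128, 512, 1024, 2048, 4096, 8192]
q_head_dim_list = [64, 128]
q_head_num_list = [8, 16, 32]

def get_test_configs():
    # Equivalent to the nested loops: skip combinations where the query head
    # count is smaller than the GQA group size.
    for bs, group_size, seq_len_m, q_head_dim, q_head_num in itertools.product(
        bs_list, group_size_list, seq_len_list, q_head_dim_list, q_head_num_list
    ):
        if q_head_num < group_size:
            continue
        yield bs, group_size, seq_len_m, q_head_dim, q_head_num

# list(...) gives tqdm a total, so the bar can show percentage and ETA.
for cfg in tqdm(list(get_test_configs())):
    pass  # each config would drive one alignment check between the two kernels

Sizing req_to_token_index and max_len_in_batch by seq_len_m instead of a hard-coded 2048 is what allows the sweep to reach the new 4096 and 8192 sequence lengths.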
