
Commit 42dc583

test: add cos align
1 parent f7cd225 commit 42dc583

1 file changed: 57 additions & 29 deletions
@@ -1,16 +1,18 @@
-
 import unittest
 import random
 import torch
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.req_manager import ReqManager
-from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import gqa_token_decode_attention_flash_decoding_vsm
-from lightllm.models.llama.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
-from lightllm.utils.infer_utils import benchmark_time
+from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
+    gqa_token_decode_attention_flash_decoding_vsm,
+)
+from lightllm.models.llama.triton_kernel.gqa_flash_decoding import (
+    gqa_token_decode_attention_flash_decoding,
+)
+
 
 class TestVSMGQADecoding(unittest.TestCase):
-    def test_vsm_gqa_decoding_able_to_run(self):
-        # set seed
+    def test_vsm_gqa_decoding_align(self):
         random.seed(0)
         torch.manual_seed(0)
         torch.cuda.manual_seed(0)
@@ -19,55 +21,81 @@ def test_vsm_gqa_decoding_able_to_run(self):
         torch.backends.cudnn.benchmark = False
 
         bs_list = range(1, 40)
-        group_size_list = [8, 16]
-        seq_len_list = [256, 512, 1024, 2048]
+        group_size_list = [16, 32, 64]
+        seq_len_list = [128, 512, 1024, 2048]
         q_head_dim_list = [64, 128]
-        q_head_num_list = [8, 16, 32]
+        q_head_num_list = [16, 32, 64]
 
         for bs in bs_list:
             for group_size in group_size_list:
-                for seq_len in seq_len_list:
+                for seq_len_m in seq_len_list:
                     for q_head_dim in q_head_dim_list:
                         for q_head_num in q_head_num_list:
                             if q_head_num < group_size:
                                 continue
                             kv_head_num = q_head_num // group_size
                             q_head_dim = q_head_dim
                             kv_head_dim = q_head_dim
-                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len).to(torch.int32)
+                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
                             total_token_in_the_batch = seq_len.sum().item()
                             rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
 
                             q_shape = [bs, q_head_num, q_head_dim]
-                            kv_shape = [rounded_total_token_in_the_batch, kv_head_num, kv_head_dim]
+                            kv_shape = [
+                                rounded_total_token_in_the_batch,
+                                kv_head_num,
+                                kv_head_dim,
+                            ]
                             qkv_dtype = torch.float16
 
-                            q, k, v = torch.randn(q_shape, dtype=qkv_dtype, device="cuda"), torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), torch.randn(kv_shape, dtype=qkv_dtype, device="cuda")
+                            q, k, v = (
+                                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
+                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                            )
+                            q, k, v = q / 10, k / 10, v / 10
 
                             req_to_token_index = torch.zeros((bs, 2048)) - 1
                             token_index = torch.arange(rounded_total_token_in_the_batch)
-
 
                             total_count = 0
                             for i in range(bs):
-                                req_to_token_index[i, :seq_len[i]] = token_index[total_count:total_count + seq_len[i]]
+                                req_to_token_index[i, : seq_len[i]] = token_index[
+                                    total_count : total_count + seq_len[i]
+                                ]
                                 total_count += seq_len[i]
 
                             req_to_token_index = req_to_token_index.long().cuda()
-
+
                             b_req_idx = torch.arange(bs, device="cuda")
-                            state = InferStateInfo()
-                            state.req_manager = ReqManager(bs, 2048, None)
-                            state.b_req_idx = b_req_idx.cuda()
-                            state.b_seq_len = seq_len.cuda()
-                            state.max_len_in_batch = 2048
-                            state.batch_size = bs
-                            state.q_head_num = q_head_num
-                            state.q_head_dim = q_head_dim
-                            state.kv_head_num = kv_head_num
-                            state.softmax_scale = 1 / (q_head_dim ** 0.5)
-                            state.total_token_num = torch.tensor([total_token_in_the_batch], dtype=torch.int32).cuda()
-                            benchmark_time(gqa_token_decode_attention_flash_decoding_vsm, q, k, v, state, warmup=0, repeat=1)
+                            infer_state = InferStateInfo()
+                            infer_state.req_manager = ReqManager(bs, 2048, None)
+                            infer_state.req_manager.req_to_token_indexs = req_to_token_index
+                            infer_state.b_req_idx = b_req_idx.cuda()
+                            infer_state.b_seq_len = seq_len.cuda()
+                            infer_state.max_len_in_batch = 2048
+                            infer_state.batch_size = bs
+                            infer_state.q_head_num = q_head_num
+                            infer_state.q_head_dim = q_head_dim
+                            infer_state.kv_head_num = kv_head_num
+                            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
+                            infer_state.total_token_num = torch.tensor(
+                                [total_token_in_the_batch], dtype=torch.int32
+                            ).cuda()
+                            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
+                            old_out = gqa_token_decode_attention_flash_decoding(
+                                q,
+                                infer_state,
+                                infer_state.q_head_num,
+                                infer_state.q_head_dim,
+                                k,
+                                v,
+                            )
+                            cos_sim = (
+                                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
+                            )
+                            self.assertGreaterEqual(cos_sim, 0.99)
+
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
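For context, the pass criterion this commit introduces is a mean cosine similarity of at least 0.99 between the new VSM kernel's output and the reference flash-decoding kernel's output on identical inputs. A minimal standalone sketch of that check, assuming plain PyTorch tensors (the helper outputs_aligned and the toy shapes are illustrative, not part of the commit):

import torch


def outputs_aligned(new_out: torch.Tensor, old_out: torch.Tensor, threshold: float = 0.99) -> bool:
    # Cosine similarity along the last (head_dim) axis, averaged over batch and
    # heads, mirroring the assertion in the test above.
    cos_sim = torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().item()
    return cos_sim >= threshold


# Toy usage: two [batch, q_head_num, q_head_dim] tensors that differ only by a
# small perturbation should still count as aligned.
a = torch.randn(4, 16, 128)
b = a + 1e-3 * torch.randn_like(a)
assert outputs_aligned(a, b)

Cosine similarity tolerates the small element-wise differences expected between two fp16 kernel implementations while still catching indexing or layout mistakes; the commit's scaling of q, k, v by 1/10 presumably keeps the fp16 values in a well-conditioned range for this comparison.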
