@@ -1,16 +1,18 @@
-
 import unittest
 import random
 import torch
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.req_manager import ReqManager
-from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import gqa_token_decode_attention_flash_decoding_vsm
-from lightllm.models.llama.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
-from lightllm.utils.infer_utils import benchmark_time
+from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
+    gqa_token_decode_attention_flash_decoding_vsm,
+)
+from lightllm.models.llama.triton_kernel.gqa_flash_decoding import (
+    gqa_token_decode_attention_flash_decoding,
+)
+
 
 class TestVSMGQADecoding(unittest.TestCase):
-    def test_vsm_gqa_decoding_able_to_run(self):
-        # set seed
+    def test_vsm_gqa_decoding_align(self):
         random.seed(0)
         torch.manual_seed(0)
         torch.cuda.manual_seed(0)
@@ -19,55 +21,81 @@ def test_vsm_gqa_decoding_able_to_run(self):
         torch.backends.cudnn.benchmark = False
 
         bs_list = range(1, 40)
-        group_size_list = [8, 16]
-        seq_len_list = [256, 512, 1024, 2048]
+        group_size_list = [16, 32, 64]
+        seq_len_list = [128, 512, 1024, 2048]
         q_head_dim_list = [64, 128]
-        q_head_num_list = [8, 16, 32]
+        q_head_num_list = [16, 32, 64]
 
         for bs in bs_list:
             for group_size in group_size_list:
-                for seq_len in seq_len_list:
+                for seq_len_m in seq_len_list:
                     for q_head_dim in q_head_dim_list:
                         for q_head_num in q_head_num_list:
                             if q_head_num < group_size:
                                 continue
                             kv_head_num = q_head_num // group_size
                             q_head_dim = q_head_dim
                             kv_head_dim = q_head_dim
-                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len).to(torch.int32)
+                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
                             total_token_in_the_batch = seq_len.sum().item()
                             rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
 
                             q_shape = [bs, q_head_num, q_head_dim]
-                            kv_shape = [rounded_total_token_in_the_batch, kv_head_num, kv_head_dim]
+                            kv_shape = [
+                                rounded_total_token_in_the_batch,
+                                kv_head_num,
+                                kv_head_dim,
+                            ]
                             qkv_dtype = torch.float16
 
-                            q, k, v = torch.randn(q_shape, dtype=qkv_dtype, device="cuda"), torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"), torch.randn(kv_shape, dtype=qkv_dtype, device="cuda")
+                            q, k, v = (
+                                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
+                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                            )
+                            q, k, v = q / 10, k / 10, v / 10
 
                             req_to_token_index = torch.zeros((bs, 2048)) - 1
                             token_index = torch.arange(rounded_total_token_in_the_batch)
-
 
                             total_count = 0
                             for i in range(bs):
-                                req_to_token_index[i, :seq_len[i]] = token_index[total_count:total_count + seq_len[i]]
+                                req_to_token_index[i, : seq_len[i]] = token_index[
+                                    total_count : total_count + seq_len[i]
+                                ]
                                 total_count += seq_len[i]
 
                             req_to_token_index = req_to_token_index.long().cuda()
-
+
                             b_req_idx = torch.arange(bs, device="cuda")
-                            state = InferStateInfo()
-                            state.req_manager = ReqManager(bs, 2048, None)
-                            state.b_req_idx = b_req_idx.cuda()
-                            state.b_seq_len = seq_len.cuda()
-                            state.max_len_in_batch = 2048
-                            state.batch_size = bs
-                            state.q_head_num = q_head_num
-                            state.q_head_dim = q_head_dim
-                            state.kv_head_num = kv_head_num
-                            state.softmax_scale = 1 / (q_head_dim ** 0.5)
-                            state.total_token_num = torch.tensor([total_token_in_the_batch], dtype=torch.int32).cuda()
-                            benchmark_time(gqa_token_decode_attention_flash_decoding_vsm, q, k, v, state, warmup=0, repeat=1)
+                            infer_state = InferStateInfo()
+                            infer_state.req_manager = ReqManager(bs, 2048, None)
+                            infer_state.req_manager.req_to_token_indexs = req_to_token_index
+                            infer_state.b_req_idx = b_req_idx.cuda()
+                            infer_state.b_seq_len = seq_len.cuda()
+                            infer_state.max_len_in_batch = 2048
+                            infer_state.batch_size = bs
+                            infer_state.q_head_num = q_head_num
+                            infer_state.q_head_dim = q_head_dim
+                            infer_state.kv_head_num = kv_head_num
+                            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
+                            infer_state.total_token_num = torch.tensor(
+                                [total_token_in_the_batch], dtype=torch.int32
+                            ).cuda()
+                            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
+                            old_out = gqa_token_decode_attention_flash_decoding(
+                                q,
+                                infer_state,
+                                infer_state.q_head_num,
+                                infer_state.q_head_dim,
+                                k,
+                                v,
+                            )
+                            cos_sim = (
+                                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
+                            )
+                            self.assertGreaterEqual(cos_sim, 0.99)
+
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
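
The renamed test replaces the old smoke run (a single benchmark_time call on the VSM kernel) with an output-alignment check: both kernels are run on the same inputs and their outputs are compared via mean cosine similarity over the head dimension. A minimal sketch of that comparison pattern, outside this commit and with hypothetical names, assuming two same-shaped [batch, heads, head_dim] outputs:

import torch


def outputs_align(new_out: torch.Tensor, old_out: torch.Tensor, threshold: float = 0.99) -> bool:
    """Hypothetical helper: True if two attention outputs agree closely."""
    # Cosine similarity along the head dimension, averaged over batch and heads.
    cos_sim = torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().item()
    return cos_sim >= threshold


if __name__ == "__main__":
    a = torch.randn(4, 16, 128)
    print(outputs_align(a, a + 1e-3 * torch.randn_like(a)))  # expected: True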