11import unittest
22import random
33import torch
4+ from tqdm import tqdm
45from lightllm .common .basemodel .infer_struct import InferStateInfo
56from lightllm .common .req_manager import ReqManager
67from lightllm .models .llama .triton_kernel .gqa_flash_decoding_vsm import (
@@ -20,81 +21,83 @@ def test_vsm_gqa_decoding_align(self):
2021 torch .backends .cudnn .deterministic = True
2122 torch .backends .cudnn .benchmark = False
2223
23- bs_list = range ( 1 , 40 )
24+ bs_list = [ 1 , 8 , 16 , 32 , 64 , 128 , 256 ]
2425 group_size_list = [16 , 32 , 64 ]
25- seq_len_list = [128 , 512 , 1024 , 2048 ]
26+ seq_len_list = [128 , 512 , 1024 , 2048 , 4096 , 8192 ]
2627 q_head_dim_list = [64 , 128 ]
27- q_head_num_list = [16 , 32 , 64 ]
28+ q_head_num_list = [8 , 16 , 32 ]
2829
29- for bs in bs_list :
30- for group_size in group_size_list :
31- for seq_len_m in seq_len_list :
32- for q_head_dim in q_head_dim_list :
33- for q_head_num in q_head_num_list :
34- if q_head_num < group_size :
35- continue
36- kv_head_num = q_head_num // group_size
37- q_head_dim = q_head_dim
38- kv_head_dim = q_head_dim
39- seq_len = (torch .zeros (bs , dtype = torch .int32 ) + seq_len_m ).to (torch .int32 )
40- total_token_in_the_batch = seq_len .sum ().item ()
41- rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1 ) // 128 * 128
30+ def get_test_configs ():
31+ for bs in bs_list :
32+ for group_size in group_size_list :
33+ for seq_len_m in seq_len_list :
34+ for q_head_dim in q_head_dim_list :
35+ for q_head_num in q_head_num_list :
36+ if q_head_num < group_size :
37+ continue
38+ yield bs , group_size , seq_len_m , q_head_dim , q_head_num
4239
43- q_shape = [ bs , q_head_num , q_head_dim ]
44- kv_shape = [
45- rounded_total_token_in_the_batch ,
46- kv_head_num ,
47- kv_head_dim ,
48- ]
49- qkv_dtype = torch . float16
40+ for bs , group_size , seq_len_m , q_head_dim , q_head_num in tqdm ( list ( get_test_configs ())):
41+ kv_head_num = q_head_num // group_size
42+ q_head_dim = q_head_dim
43+ kv_head_dim = q_head_dim
44+ seq_len = ( torch . zeros ( bs , dtype = torch . int32 ) + seq_len_m ). to ( torch . int32 )
45+ total_token_in_the_batch = seq_len . sum (). item ()
46+ rounded_total_token_in_the_batch = ( total_token_in_the_batch + 128 - 1 ) // 128 * 128
5047
51- q , k , v = (
52- torch .randn (q_shape , dtype = qkv_dtype , device = "cuda" ),
53- torch .randn (kv_shape , dtype = qkv_dtype , device = "cuda" ),
54- torch .randn (kv_shape , dtype = qkv_dtype , device = "cuda" ),
55- )
56- q , k , v = q / 10 , k / 10 , v / 10
48+ q_shape = [bs , q_head_num , q_head_dim ]
49+ kv_shape = [
50+ rounded_total_token_in_the_batch ,
51+ kv_head_num ,
52+ kv_head_dim ,
53+ ]
54+ qkv_dtype = torch .float16
5755
58- req_to_token_index = torch .zeros ((bs , 2048 )) - 1
59- token_index = torch .arange (rounded_total_token_in_the_batch )
56+ q , k , v = (
57+ torch .randn (q_shape , dtype = qkv_dtype , device = "cuda" ),
58+ torch .randn (kv_shape , dtype = qkv_dtype , device = "cuda" ),
59+ torch .randn (kv_shape , dtype = qkv_dtype , device = "cuda" ),
60+ )
61+ q , k , v = q / 10 , k / 10 , v / 10
6062
61- total_count = 0
62- for i in range (bs ):
63- req_to_token_index [i , : seq_len [i ]] = token_index [
64- total_count : total_count + seq_len [i ]
65- ]
66- total_count += seq_len [i ]
63+ req_to_token_index = torch .zeros ((bs , seq_len_m )) - 1
64+ token_index = torch .arange (rounded_total_token_in_the_batch )
6765
68- req_to_token_index = req_to_token_index .long ().cuda ()
66+ total_count = 0
67+ for i in range (bs ):
68+ req_to_token_index [i , : seq_len [i ]] = token_index [total_count : total_count + seq_len [i ]]
69+ total_count += seq_len [i ]
6970
70- b_req_idx = torch .arange (bs , device = "cuda" )
71- infer_state = InferStateInfo ()
72- infer_state .req_manager = ReqManager (bs , 2048 , None )
73- infer_state .req_manager .req_to_token_indexs = req_to_token_index
74- infer_state .b_req_idx = b_req_idx .cuda ()
75- infer_state .b_seq_len = seq_len .cuda ()
76- infer_state .max_len_in_batch = 2048
77- infer_state .batch_size = bs
78- infer_state .q_head_num = q_head_num
79- infer_state .q_head_dim = q_head_dim
80- infer_state .kv_head_num = kv_head_num
81- infer_state .softmax_scale = 1 / (q_head_dim ** 0.5 )
82- infer_state .total_token_num = torch .tensor (
83- [total_token_in_the_batch ], dtype = torch .int32
84- ).cuda ()
85- new_out = gqa_token_decode_attention_flash_decoding_vsm (q , k , v , infer_state )
86- old_out = gqa_token_decode_attention_flash_decoding (
87- q ,
88- infer_state ,
89- infer_state .q_head_num ,
90- infer_state .q_head_dim ,
91- k ,
92- v ,
93- )
94- cos_sim = (
95- torch .nn .functional .cosine_similarity (new_out , old_out , dim = - 1 ).mean ().cpu ().item ()
96- )
97- self .assertGreaterEqual (cos_sim , 0.99 )
71+ req_to_token_index = req_to_token_index .long ().cuda ()
72+
73+ b_req_idx = torch .arange (bs , device = "cuda" )
74+ infer_state = InferStateInfo ()
75+ infer_state .req_manager = ReqManager (bs , 2048 , None )
76+ infer_state .req_manager .req_to_token_indexs = req_to_token_index
77+ infer_state .b_req_idx = b_req_idx .cuda ()
78+ infer_state .b_seq_len = seq_len .cuda ()
79+ infer_state .max_len_in_batch = seq_len_m
80+ infer_state .batch_size = bs
81+ infer_state .q_head_num = q_head_num
82+ infer_state .q_head_dim = q_head_dim
83+ infer_state .kv_head_num = kv_head_num
84+ infer_state .softmax_scale = 1 / (q_head_dim ** 0.5 )
85+ infer_state .total_token_num = torch .tensor ([total_token_in_the_batch ], dtype = torch .int32 ).cuda ()
86+ new_out = gqa_token_decode_attention_flash_decoding_vsm (q , k , v , infer_state )
87+ old_out = gqa_token_decode_attention_flash_decoding (
88+ q ,
89+ infer_state ,
90+ infer_state .q_head_num ,
91+ infer_state .q_head_dim ,
92+ k ,
93+ v ,
94+ )
95+ cos_sim = torch .nn .functional .cosine_similarity (new_out , old_out , dim = - 1 ).mean ().cpu ().item ()
96+ self .assertGreaterEqual (
97+ cos_sim ,
98+ 0.9 ,
99+ f"bs={ bs } , group_size={ group_size } , seq_len={ seq_len_m } , q_head_dim={ q_head_dim } , q_head_num={ q_head_num } " ,
100+ )
98101
99102
100103if __name__ == "__main__" :
0 commit comments