 import unittest
 import random
 import torch
+from tqdm import tqdm
 from lightllm.common.basemodel.infer_struct import InferStateInfo
 from lightllm.common.req_manager import ReqManager
 from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
@@ -20,81 +21,84 @@ def test_vsm_gqa_decoding_align(self):
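+        # Fix cuDNN to deterministic kernels so the kernel comparison below is reproducible.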
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False

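+        # Sweep batch size, GQA group size, sequence length, and head-shape combinations.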
-        bs_list = range(1, 40)
+        bs_list = [1, 8, 16, 32, 64, 128, 256]
         group_size_list = [16, 32, 64]
-        seq_len_list = [128, 512, 1024, 2048]
+        seq_len_list = [128, 512, 1024, 2048, 4096, 8192]
         q_head_dim_list = [64, 128]
-        q_head_num_list = [16, 32, 64]
+        q_head_num_list = [8, 16, 32]

-        for bs in bs_list:
-            for group_size in group_size_list:
-                for seq_len_m in seq_len_list:
-                    for q_head_dim in q_head_dim_list:
-                        for q_head_num in q_head_num_list:
-                            if q_head_num < group_size:
-                                continue
-                            kv_head_num = q_head_num // group_size
-                            q_head_dim = q_head_dim
-                            kv_head_dim = q_head_dim
-                            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
-                            total_token_in_the_batch = seq_len.sum().item()
-                            rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128
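+        # Enumerate valid configs lazily; materializing the generator lets tqdm report progress.
+        # Configs with q_head_num < group_size are skipped, since GQA needs at least one
+        # query head per KV head.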
+        def get_test_configs():
+            for bs in bs_list:
+                for group_size in group_size_list:
+                    for seq_len_m in seq_len_list:
+                        for q_head_dim in q_head_dim_list:
+                            for q_head_num in q_head_num_list:
+                                if q_head_num < group_size:
+                                    continue
+                                yield bs, group_size, seq_len_m, q_head_dim, q_head_num
+        for bs, group_size, seq_len_m, q_head_dim, q_head_num in tqdm(list(get_test_configs())):
+            kv_head_num = q_head_num // group_size
+            q_head_dim = q_head_dim
+            kv_head_dim = q_head_dim
+            seq_len = (torch.zeros(bs, dtype=torch.int32) + seq_len_m).to(torch.int32)
+            total_token_in_the_batch = seq_len.sum().item()
+            rounded_total_token_in_the_batch = (total_token_in_the_batch + 128 - 1) // 128 * 128

-                            q_shape = [bs, q_head_num, q_head_dim]
-                            kv_shape = [
-                                rounded_total_token_in_the_batch,
-                                kv_head_num,
-                                kv_head_dim,
-                            ]
-                            qkv_dtype = torch.float16
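+            # q holds one decode token per request; k/v are flat token pools whose length
+            # is the batch's token count rounded up to a multiple of 128.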
+            q_shape = [bs, q_head_num, q_head_dim]
+            kv_shape = [
+                rounded_total_token_in_the_batch,
+                kv_head_num,
+                kv_head_dim,
+            ]
+            qkv_dtype = torch.float16

-                            q, k, v = (
-                                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
-                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
-                                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
-                            )
-                            q, k, v = q / 10, k / 10, v / 10
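+            # Scale the random inputs down, presumably to keep fp16 attention numerics well-behaved.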
+            q, k, v = (
+                torch.randn(q_shape, dtype=qkv_dtype, device="cuda"),
+                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+                torch.randn(kv_shape, dtype=qkv_dtype, device="cuda"),
+            )
+            q, k, v = q / 10, k / 10, v / 10

-                            req_to_token_index = torch.zeros((bs, 2048)) - 1
-                            token_index = torch.arange(rounded_total_token_in_the_batch)
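+            # Per-request token-index table, sized to the current max sequence length,
+            # with unused slots initialized to -1.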
+            req_to_token_index = torch.zeros((bs, seq_len_m)) - 1
+            token_index = torch.arange(rounded_total_token_in_the_batch)

-                            total_count = 0
-                            for i in range(bs):
-                                req_to_token_index[i, : seq_len[i]] = token_index[
-                                    total_count : total_count + seq_len[i]
-                                ]
-                                total_count += seq_len[i]
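+            # Hand each request a contiguous slice of the token pool.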
+            total_count = 0
+            for i in range(bs):
+                req_to_token_index[i, : seq_len[i]] = token_index[
+                    total_count : total_count + seq_len[i]
+                ]
+                total_count += seq_len[i]

-                            req_to_token_index = req_to_token_index.long().cuda()
+            req_to_token_index = req_to_token_index.long().cuda()

-                            b_req_idx = torch.arange(bs, device="cuda")
-                            infer_state = InferStateInfo()
-                            infer_state.req_manager = ReqManager(bs, 2048, None)
-                            infer_state.req_manager.req_to_token_indexs = req_to_token_index
-                            infer_state.b_req_idx = b_req_idx.cuda()
-                            infer_state.b_seq_len = seq_len.cuda()
-                            infer_state.max_len_in_batch = 2048
-                            infer_state.batch_size = bs
-                            infer_state.q_head_num = q_head_num
-                            infer_state.q_head_dim = q_head_dim
-                            infer_state.kv_head_num = kv_head_num
-                            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
-                            infer_state.total_token_num = torch.tensor(
-                                [total_token_in_the_batch], dtype=torch.int32
-                            ).cuda()
-                            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
-                            old_out = gqa_token_decode_attention_flash_decoding(
-                                q,
-                                infer_state,
-                                infer_state.q_head_num,
-                                infer_state.q_head_dim,
-                                k,
-                                v,
-                            )
-                            cos_sim = (
-                                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
-                            )
-                            self.assertGreaterEqual(cos_sim, 0.99)
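+            # Minimal InferStateInfo both kernels need; softmax_scale is the usual 1/sqrt(head_dim).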
+            b_req_idx = torch.arange(bs, device="cuda")
+            infer_state = InferStateInfo()
+            infer_state.req_manager = ReqManager(bs, 2048, None)
+            infer_state.req_manager.req_to_token_indexs = req_to_token_index
+            infer_state.b_req_idx = b_req_idx.cuda()
+            infer_state.b_seq_len = seq_len.cuda()
+            infer_state.max_len_in_batch = seq_len_m
+            infer_state.batch_size = bs
+            infer_state.q_head_num = q_head_num
+            infer_state.q_head_dim = q_head_dim
+            infer_state.kv_head_num = kv_head_num
+            infer_state.softmax_scale = 1 / (q_head_dim ** 0.5)
+            infer_state.total_token_num = torch.tensor(
+                [total_token_in_the_batch], dtype=torch.int32
+            ).cuda()
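+            # Run the new VSM kernel against the existing flash-decoding kernel and require
+            # their outputs to stay aligned (mean cosine similarity >= 0.9), reporting the
+            # failing config in the assertion message.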
+            new_out = gqa_token_decode_attention_flash_decoding_vsm(q, k, v, infer_state)
+            old_out = gqa_token_decode_attention_flash_decoding(
+                q,
+                infer_state,
+                infer_state.q_head_num,
+                infer_state.q_head_dim,
+                k,
+                v,
+            )
+            cos_sim = (
+                torch.nn.functional.cosine_similarity(new_out, old_out, dim=-1).mean().cpu().item()
+            )
+            self.assertGreaterEqual(
+                cos_sim,
+                0.9,
+                f"bs={bs}, group_size={group_size}, seq_len={seq_len_m}, q_head_dim={q_head_dim}, q_head_num={q_head_num}",
+            )


 if __name__ == "__main__":