@@ -37,7 +37,8 @@ def __init__(
         self.b_mark_shared_group = b_mark_shared_group


-@pytest.mark.parametrize("shared_seq_len", [32])
+# @pytest.mark.parametrize("shared_seq_len", [512])
+@pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550])
 def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len):
     """
     Test the token_decode_attention_flash_decoding of ppl_int8kv_flash_decoding_diverse
@@ -50,10 +51,11 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
         token_decode_attention_flash_decoding as baseline_attention,
     )

-    batch_size = 4
+    batch_size = 6
     num_heads = 32
     kv_head_num = 8
-    seq_len = 256
+    mark_shared_group_size = 3
+    seq_len = 1024
     head_dim = 128
     quant_group_size = 8
     test_dtype = torch.bfloat16
@@ -63,16 +65,24 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
     kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)

     q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda")
+
+    # Generate cache_k and cache_v so that every mark_shared_group_size batches share the same cache
+
     cache_k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
-    cache_k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    cache_k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0
     cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
-    cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0

     req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
+    for i in range(batch_size):
+        if i % mark_shared_group_size != 0:
+            req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len]
+
     b_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
     b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda")
-    b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+    b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size

     # Create the baseline infer_state (b_shared_seq_len is not needed)
     baseline_infer_state = MockInferState(
@@ -106,7 +116,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
         cache_v_scale=cache_v_scale,
         alloc_tensor_func=alloc_tensor_func,
     )
-
     # Run the diverse version
     diverse_out = diverse_attention(
         q=q.clone(),
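
For readers skimming the diff, the grouping bookkeeping added here works as follows: within each group of mark_shared_group_size consecutive requests, every request after the first points its first shared_seq_len token slots at its predecessor's slots, so the whole group shares the group leader's prefix, and only the last request of each group carries the group size in b_mark_shared_group. Below is a minimal CPU-only sketch of just this bookkeeping, under smaller seq_len and shared_seq_len than the test and with no attention kernels or int8 cache involved; the printed values are purely illustrative.

import torch

batch_size = 6
seq_len = 8                  # the test uses 1024
shared_seq_len = 4           # the test parametrizes this value
mark_shared_group_size = 3

# Each request initially owns its own contiguous block of token slots.
req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32).view(batch_size, seq_len)

# Non-leading requests in a group reuse their predecessor's first shared_seq_len slots,
# so all members of a group end up pointing at the leader's shared prefix.
for i in range(batch_size):
    if i % mark_shared_group_size != 0:
        req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len]

# Only the last request of each group is marked with the group size; all others stay 0.
b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32)
b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size

print(b_mark_shared_group)                # tensor([0, 0, 3, 0, 0, 3], dtype=torch.int32)
print(req_to_tokens[:, :shared_seq_len])  # rows 0..2 -> [0, 1, 2, 3]; rows 3..5 -> [24, 25, 26, 27]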