
Commit 8d9068b

fix unit_test
1 parent 6cdc335 commit 8d9068b

2 files changed (+19, -10 lines)

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py

Lines changed: 3 additions & 3 deletions
@@ -56,16 +56,15 @@ def token_decode_attention_flash_decoding(
         B_req_idx=infer_state.b_req_idx,
         b_shared_seq_len=infer_state.b_shared_seq_len,
         b_mark_shared_group=infer_state.b_mark_shared_group,
-        b_seq_len=infer_state.b_seq_len,
         max_len_in_batch=infer_state.max_len_in_batch,
         mid_out=mid_o,
         mid_out_logsumexp=mid_o_logexpsum,
-        BLOCK_SEQ=BLOCK_SEQ,
+        block_seq=BLOCK_SEQ,
         max_batch_group_size=get_diverse_max_batch_shared_group_size(),
     )
     stream2.wait_stream(current_stream)
     with torch.cuda.stream(stream2):
-        light_ops.group8_int8kv_flashdecoding_stage1(
+        light_ops.group8_int8kv_flashdecoding_diverse_stage2(
             BLOCK_SEQ,
             mid_o,
             mid_o_logexpsum,
@@ -78,6 +77,7 @@ def token_decode_attention_flash_decoding(
             infer_state.req_manager.req_to_token_indexs,
             infer_state.b_req_idx,
             infer_state.b_seq_len,
+            infer_state.b_shared_seq_len,
             infer_state.max_len_in_batch,
         )
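
The hunk above hands the stage-2 kernel to a side CUDA stream: stream2.wait_stream(current_stream) ensures the side stream only starts after the work already queued on the current stream, and the function presumably synchronizes back before its results are read (not shown in this hunk). Below is a minimal, self-contained sketch of that ordering pattern using only public PyTorch stream APIs; the matmuls are stand-ins for the Triton / light_ops kernels, and none of this is lightllm's actual code.

import torch

def two_stage_on_side_stream(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    current_stream = torch.cuda.current_stream()
    stream2 = torch.cuda.Stream()

    mid = a @ b  # stand-in for the stage-1 kernel, queued on the current stream

    # stream2 starts only after everything already queued on current_stream,
    # so the stand-in stage-2 op sees a fully written `mid`.
    stream2.wait_stream(current_stream)
    with torch.cuda.stream(stream2):
        out = mid @ b  # stand-in for the stage-2 reduce kernel

    # make the current stream wait for stream2 before anyone reads `out`
    current_stream.wait_stream(stream2)
    return out

if torch.cuda.is_available():
    x = torch.randn(64, 64, device="cuda")
    y = torch.randn(64, 64, device="cuda")
    print(two_stage_on_side_stream(x, y).shape)  # torch.Size([64, 64])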

unit_tests/models/llama/test_ppl_int8kv_flash_decoding_diverse.py

Lines changed: 16 additions & 7 deletions
@@ -37,7 +37,8 @@ def __init__(
         self.b_mark_shared_group = b_mark_shared_group
 
 
-@pytest.mark.parametrize("shared_seq_len", [32])
+# @pytest.mark.parametrize("shared_seq_len", [512])
+@pytest.mark.parametrize("shared_seq_len", [0, 77, 256, 311, 512, 550])
 def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len):
     """
     Test token_decode_attention_flash_decoding from ppl_int8kv_flash_decoding_diverse
@@ -50,10 +51,11 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
         token_decode_attention_flash_decoding as baseline_attention,
     )
 
-    batch_size = 4
+    batch_size = 6
     num_heads = 32
     kv_head_num = 8
-    seq_len = 256
+    mark_shared_group_size = 3
+    seq_len = 1024
     head_dim = 128
     quant_group_size = 8
     test_dtype = torch.bfloat16
@@ -63,16 +65,24 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
     kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)
 
     q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda")
+
+    # Generate cache_k and cache_v so that every mark_shared_group_size batches share the same cache
+
     cache_k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
-    cache_k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    cache_k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0
     cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
-    cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
+    cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda") / 100.0
 
     req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
+    for i in range(batch_size):
+        if i % mark_shared_group_size != 0:
+            req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len]
+
     b_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
     b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda")
-    b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+    b_mark_shared_group = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    b_mark_shared_group[mark_shared_group_size - 1 :: mark_shared_group_size] = mark_shared_group_size
 
     # Create the baseline infer_state (b_shared_seq_len is not needed)
     baseline_infer_state = MockInferState(
@@ -106,7 +116,6 @@ def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_le
         cache_v_scale=cache_v_scale,
         alloc_tensor_func=alloc_tensor_func,
     )
-
     # Run the diverse version
     diverse_out = diverse_attention(
         q=q.clone(),
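
With the updated parameters (batch_size = 6, mark_shared_group_size = 3), the test makes every group of three consecutive requests alias the same KV-cache slots for their first shared_seq_len tokens, and marks only the last request of each group with the group size in b_mark_shared_group, which appears to be the layout the diverse kernel's b_shared_seq_len / b_mark_shared_group arguments describe. A small CPU-only sketch of that bookkeeping with toy sizes (illustration only, not part of the test):

import torch

# Toy sizes picked for readability; the test itself uses seq_len=1024 on CUDA tensors.
batch_size, seq_len, shared_seq_len, group_size = 6, 16, 4, 3

req_to_tokens = torch.arange(batch_size * seq_len, dtype=torch.int32).view(batch_size, seq_len)
for i in range(batch_size):
    if i % group_size != 0:
        # every member of a group points its shared prefix at the previous row's cache slots
        req_to_tokens[i, :shared_seq_len] = req_to_tokens[i - 1, :shared_seq_len]

b_mark_shared_group = torch.zeros(batch_size, dtype=torch.int32)
b_mark_shared_group[group_size - 1 :: group_size] = group_size  # last row of each group carries the group size

print(req_to_tokens[:, :shared_seq_len])  # rows 0-2 identical, rows 3-5 identical
print(b_mark_shared_group)                # tensor([0, 0, 3, 0, 0, 3], dtype=torch.int32)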
