
Commit 6cdc335

add test_token_decode_attention_flash_decoding_diverse (currently failing)
1 parent 5d295e0 commit 6cdc335

1 file changed: 131 additions, 0 deletions
@@ -0,0 +1,131 @@
import pytest
import torch
from lightllm.utils.light_utils import light_ops


def alloc_tensor_func(shape, dtype, device):
    """Compatible tensor allocation function."""
    return torch.empty(shape, dtype=dtype, device=device)


class MockReqManager:
    """Mock request manager for testing"""

    def __init__(self, req_to_token_indexs):
        self.req_to_token_indexs = req_to_token_indexs


class MockInferState:
    """Mock infer state for testing"""

    def __init__(
        self,
        batch_size,
        max_len_in_batch,
        req_to_tokens,
        b_req_idx,
        b_seq_len,
        b_shared_seq_len=None,
        b_mark_shared_group=None,
    ):
        self.batch_size = batch_size
        self.max_len_in_batch = max_len_in_batch
        self.req_manager = MockReqManager(req_to_tokens)
        self.b_req_idx = b_req_idx
        self.b_seq_len = b_seq_len
        self.b_shared_seq_len = b_shared_seq_len
        self.b_mark_shared_group = b_mark_shared_group


@pytest.mark.parametrize("shared_seq_len", [32])
def test_token_decode_attention_flash_decoding_diverse_vs_baseline(shared_seq_len):
    """
    Compare token_decode_attention_flash_decoding from ppl_int8kv_flash_decoding_diverse
    against token_decode_attention_flash_decoding from ppl_int8kv_flash_decoding (baseline).
    """
    from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import (
        token_decode_attention_flash_decoding as diverse_attention,
    )
    from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding import (
        token_decode_attention_flash_decoding as baseline_attention,
    )

    batch_size = 4
    num_heads = 32
    kv_head_num = 8
    seq_len = 256
    head_dim = 128
    quant_group_size = 8
    test_dtype = torch.bfloat16

    # Create test data
    kv_shape = (batch_size * seq_len, kv_head_num, head_dim)
    kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)

    q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda")
    cache_k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
    cache_k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
    cache_v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
    cache_v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")

    req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
    b_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
    b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda")
    b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda")

    # Create the baseline infer_state (b_shared_seq_len is not needed)
    baseline_infer_state = MockInferState(
        batch_size=batch_size,
        max_len_in_batch=seq_len,
        req_to_tokens=req_to_tokens,
        b_req_idx=b_req_idx,
        b_seq_len=b_seq_len,
    )

    # Create the diverse infer_state
    diverse_infer_state = MockInferState(
        batch_size=batch_size,
        max_len_in_batch=seq_len,
        req_to_tokens=req_to_tokens,
        b_req_idx=b_req_idx,
        b_seq_len=b_seq_len,
        b_shared_seq_len=b_shared_seq_len,
        b_mark_shared_group=b_mark_shared_group,
    )

    # Run the baseline
    baseline_out = baseline_attention(
        q=q.clone(),
        infer_state=baseline_infer_state,
        q_head_num=num_heads,
        head_dim=head_dim,
        cache_k=cache_k,
        cache_k_scale=cache_k_scale,
        cache_v=cache_v,
        cache_v_scale=cache_v_scale,
        alloc_tensor_func=alloc_tensor_func,
    )

    # Run the diverse version
    diverse_out = diverse_attention(
        q=q.clone(),
        infer_state=diverse_infer_state,
        q_head_num=num_heads,
        head_dim=head_dim,
        cache_k=cache_k,
        cache_k_scale=cache_k_scale,
        cache_v=cache_v,
        cache_v_scale=cache_v_scale,
        alloc_tensor_func=alloc_tensor_func,
    )

    print(f"\nshared_seq_len={shared_seq_len}")
    print(f"baseline_out: {baseline_out[0, 0, :4]}")
    print(f"diverse_out: {diverse_out[0, 0, :4]}")
    print(f"max diff: {(baseline_out - diverse_out).abs().max()}")

    # Compare against the baseline
    assert torch.allclose(
        baseline_out, diverse_out, atol=1e-2, rtol=1e-2
    ), f"Diverse attention output should match baseline for shared_seq_len={shared_seq_len}"

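Since the assertion currently fails, it can help to run only this test with its printed output visible. A minimal sketch, assuming a CUDA device and a lightllm build that provides both Triton kernel variants, would be to append the following to the bottom of the test file:

if __name__ == "__main__":
    import pytest
    # Run only this file with stdout shown, so the printed max diff is visible.
    raise SystemExit(pytest.main(["-s", "-v", __file__]))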