
Commit 2ae7084

add stage2 unit_test
1 parent d35a2cf commit 2ae7084

1 file changed: +121 −0 lines changed
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
import pytest
import torch

from lightllm.utils.light_utils import light_ops


def create_tensors(shared_seq_len):
    """Build the query, int8 KV cache, scales, and batch metadata used by the stage-2 kernel test."""
    batch_size = 4
    num_heads = 4
    kv_head_num = 1
    seq_len = 256
    head_dim = 128
    max_len_in_batch = seq_len
    block_seq = 256
    max_batch_group_size = 4
    quant_group_size = 8

    test_dtype = torch.bfloat16

    kv_shape = (batch_size * seq_len, kv_head_num, head_dim)
    kv_scale_shape = (batch_size * seq_len, kv_head_num, head_dim // quant_group_size)

    q = torch.randn(size=(batch_size, num_heads, head_dim), dtype=test_dtype, device="cuda")
    # int8 KV cache with per-group scales (quant_group_size=8); scales of 1.0 make dequantization a plain cast
    k = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
    k_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
    v = torch.randint(low=-100, high=100, size=kv_shape, dtype=torch.int8, device="cuda")
    v_scale = torch.ones(size=kv_scale_shape, dtype=test_dtype, device="cuda")
    Req_to_tokens = torch.arange(0, seq_len * batch_size, dtype=torch.int32, device="cuda").view(batch_size, seq_len)
    B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
    b_seq_len = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
    b_shared_seq_len = torch.full((batch_size,), shared_seq_len, dtype=torch.int32, device="cuda")
    b_mark_shared_group = torch.ones(batch_size, dtype=torch.int32, device="cuda")
    # intermediate per-block outputs and log-sum-exp buffers written by the flash-decoding kernels
    mid_out = torch.zeros(
        size=(batch_size, num_heads, (seq_len // block_seq) + 2, head_dim), dtype=q.dtype, device="cuda"
    )
    mid_out_logsumexp = torch.zeros(
        size=(batch_size, num_heads, (seq_len // block_seq) + 2), dtype=q.dtype, device="cuda"
    )

    return {
        "q": q,
        "k": k,
        "k_scale": k_scale,
        "v": v,
        "v_scale": v_scale,
        "Req_to_tokens": Req_to_tokens,
        "B_req_idx": B_req_idx,
        "b_seq_len": b_seq_len,
        "b_shared_seq_len": b_shared_seq_len,
        "b_mark_shared_group": b_mark_shared_group,
        "max_len_in_batch": max_len_in_batch,
        "mid_out": mid_out,
        "mid_out_logsumexp": mid_out_logsumexp,
        "block_seq": block_seq,
        "max_batch_group_size": max_batch_group_size,
        "head_dim": head_dim,
    }


@pytest.mark.parametrize("shared_seq_len", [0, 47, 77, 128, 200, 255])
def test_flash_decode_stage2_execution(shared_seq_len):
    setup_tensors = create_tensors(shared_seq_len)

    # Run the stage-2 kernel under test on the int8 KV cache.
    light_ops.group8_int8kv_flashdecoding_diverse_stage2(
        setup_tensors["block_seq"],
        setup_tensors["mid_out"],
        setup_tensors["mid_out_logsumexp"],
        1.0 / (setup_tensors["head_dim"] ** 0.5),
        setup_tensors["q"],
        setup_tensors["k"],
        setup_tensors["k_scale"],
        setup_tensors["v"],
        setup_tensors["v_scale"],
        setup_tensors["Req_to_tokens"],
        setup_tensors["B_req_idx"],
        setup_tensors["b_seq_len"],
        setup_tensors["b_shared_seq_len"],
        setup_tensors["max_len_in_batch"],
    )
    # Compare only the blocks that come after the shared-prefix blocks.
    seq_block_idx = (setup_tensors["b_shared_seq_len"][0].item() + setup_tensors["block_seq"] - 1) // setup_tensors[
        "block_seq"
    ]
    mid_out = setup_tensors["mid_out"][:, :, seq_block_idx:, :]
    mid_out_logsumexp = setup_tensors["mid_out_logsumexp"][:, :, seq_block_idx:]

    q = setup_tensors["q"]
    k = setup_tensors["k"]
    v = setup_tensors["v"]
    true_mid_out = torch.zeros_like(mid_out)
    true_mid_out_logsumexp = torch.zeros_like(mid_out_logsumexp)
    # Reference path: cast the int8 KV back to bf16 (scales are 1.0) and run the existing
    # GQA flash-decoding stage-1 Triton kernel on the non-shared suffix of each sequence.
    new_q = q
    new_k = k.to(q.dtype)
    new_v = v.to(q.dtype)

    b_seq_len = setup_tensors["b_seq_len"] - setup_tensors["b_shared_seq_len"]
    req_to_tokens = setup_tensors["Req_to_tokens"][:, setup_tensors["b_shared_seq_len"][0].item() :]

    from lightllm.models.llama.triton_kernel.gqa_flash_decoding_stage1 import (
        flash_decode_stage1 as gqa_flash_decode_stage1,
    )

    gqa_flash_decode_stage1(
        q=new_q,
        k=new_k,
        v=new_v,
        Req_to_tokens=req_to_tokens,
        B_req_idx=setup_tensors["B_req_idx"],
        B_Seqlen=b_seq_len,
        max_len_in_batch=setup_tensors["max_len_in_batch"],
        mid_out=true_mid_out,
        mid_out_logsumexp=true_mid_out_logsumexp,
        block_seq=setup_tensors["block_seq"],
    )
    print(f"\nshared_seq_len={shared_seq_len}")
    print(f"mid_out: {mid_out[0:4, 0, 0, 0]}")
    print(f"true_mid_out: {true_mid_out[0:4, 0, 0, 0]}")
    # Kernel output must agree with the reference within a bf16-friendly tolerance.
    assert torch.allclose(
        mid_out[0:4, 0, 0, 0], true_mid_out[0:4, 0, 0, 0], atol=1e-2
    ), f"Mid output does not match expected values for shared_seq_len={shared_seq_len}"
    assert torch.allclose(
        mid_out_logsumexp, true_mid_out_logsumexp, atol=1e-2
    ), f"LogSumExp output does not match expected values for shared_seq_len={shared_seq_len}"
