Skip to content

Commit 97416f4

Browse files
add unit test for kv trans kernel (#1059)
1 parent fe54874 commit 97416f4

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import os
2+
import pytest
3+
import torch
4+
5+
6+
# Skip every test in this module when no CUDA device is present: the kernels
# under test are Triton kernels, which require a GPU to compile and launch.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for Triton kernels")
7+
8+
9+
def _randn_like(shape, dtype=torch.float32, device="cuda"):
10+
torch.manual_seed(1234)
11+
return torch.randn(shape, dtype=dtype, device=device).contiguous()
12+
13+
14+
def _gather_along_size(t, idx):
15+
return t.index_select(dim=1, index=idx)
16+
17+
18+
@pytest.mark.parametrize("head_dim", [64, 96])
@pytest.mark.parametrize(
    "kv_head_num,tp_world_size",
    [
        (4, 1),
        (4, 2),
        (8, 1),
        (8, 2),
        (8, 4),
        (16, 2),
    ],
)
def test_page_io_roundtrip_with_tp(head_dim, kv_head_num, tp_world_size):
    """Round-trip ``page_io`` (write then read) and check it against a pure-Python reference.

    The test writes selected tokens from ``kv_buffer`` into each TP rank's head
    partition of ``page_tensor``, compares the page against an explicitly-built
    expected layout, then reads back into a fresh buffer and checks equality at
    the selected token positions.

    NOTE(review): all ranks here share one ``kv_buffer``; in real use each rank
    presumably holds its own shard — TODO confirm against the kernel's callers.
    """
    # Imported inside the test so collection succeeds without lightllm installed.
    from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io

    device = "cuda"
    dtype = torch.bfloat16

    # Shapes
    layer_num = 3
    size = 32  # number of token slots in the kv buffer
    assert kv_head_num % 2 == 0  # heads split evenly into K and V halves below
    page_size = 32
    # The page holds the concatenation of every rank's head partition.
    page_head_num = kv_head_num * tp_world_size

    tp_indices = list(range(tp_world_size))

    # kv_buffer: (layer, token, head, head_dim); page_tensor: (token, layer, head, head_dim).
    kv_buffer = _randn_like((layer_num, size, kv_head_num, head_dim), dtype=dtype, device=device)
    page_tensor = torch.zeros((page_size, layer_num, page_head_num, head_dim), dtype=dtype, device=device).contiguous()

    # Select a handful of token positions to move (tid count <= page_size)
    mem_indexes = torch.tensor([2, 5, 7, 9, 11], dtype=torch.int64, device=device).contiguous()

    # Write: kv_buffer -> page_tensor, done by all tp ranks to fill their partition
    for tp_index in tp_indices:
        page_io(
            mem_indexes=mem_indexes,
            page_tensor=page_tensor,
            kv_buffer=kv_buffer,
            tp_index=tp_index,
            tp_world_size=tp_world_size,
            mode="write",
        )

    # After-write expectation check
    token_num = mem_indexes.numel()
    k_head_num = kv_head_num // 2  # first half of kv heads = K, second half = V
    page_k_head_num = page_head_num // 2
    # Heads each rank contributes to the page's K (and, symmetrically, V) half.
    page_write_head_num = page_k_head_num // tp_world_size

    # Build the expected page layout token-by-token in plain Python, mirroring
    # the assumed kernel layout: per token, per layer, rank tp_index occupies
    # heads [tp_index * page_write_head_num, ...) in the K half, and the same
    # offset shifted by page_k_head_num in the V half.
    expected_page = torch.zeros_like(page_tensor)
    for tid in range(token_num):
        mem_index = int(mem_indexes[tid].item())
        for layer_index in range(layer_num):
            for tp_index in tp_indices:
                page_head_start = tp_index * page_write_head_num
                for kv_head_id in range(page_write_head_num):
                    # K half
                    expected_page[tid, layer_index, page_head_start + kv_head_id, :] = kv_buffer[
                        layer_index, mem_index, kv_head_id, :
                    ]
                    # V half
                    expected_page[tid, layer_index, page_k_head_num + page_head_start + kv_head_id, :] = kv_buffer[
                        layer_index, mem_index, k_head_num + kv_head_id, :
                    ]

    # Only the first token_num page slots are written; the rest stay zero.
    assert torch.allclose(page_tensor[:token_num], expected_page[:token_num], atol=1e-3, rtol=1e-3)

    # Read back to a fresh buffer
    kv_buffer_rt = torch.zeros_like(kv_buffer)
    for tp_index in tp_indices:
        page_io(
            mem_indexes=mem_indexes,
            page_tensor=page_tensor,
            kv_buffer=kv_buffer_rt,
            tp_index=tp_index,
            tp_world_size=tp_world_size,
            mode="read",
        )

    # Check equality only at selected positions along size-dim
    ref = _gather_along_size(kv_buffer, mem_indexes)
    out = _gather_along_size(kv_buffer_rt, mem_indexes)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
102+
103+
104+
@pytest.mark.parametrize("head_dim", [32, 80])  # include non-power-of-two
def test_mla_page_io_roundtrip(head_dim):
    """Write selected tokens through ``mla_page_io`` and read them back unchanged.

    Single-head (MLA) variant: fill a page from ``kv_buffer``, verify the page
    layout against a hand-built reference, then restore into a zeroed buffer
    and compare at the selected token positions.
    """
    # Deferred import keeps collection working when lightllm is unavailable.
    from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io

    device = "cuda"
    dtype = torch.bfloat16

    # Shapes (single-head)
    layer_num, size = 2, 20
    kv_head_num = 1
    page_size, page_head_num = 10, 1

    kv_buffer = _randn_like((layer_num, size, kv_head_num, head_dim), dtype=dtype, device=device)
    page_tensor = torch.zeros((page_size, layer_num, page_head_num, head_dim), dtype=dtype, device=device).contiguous()

    mem_indexes = torch.tensor([0, 3, 6, 7], dtype=torch.int64, device=device).contiguous()
    selected = mem_indexes.tolist()
    token_num = len(selected)

    # Write kv -> page
    mla_page_io(mem_indexes=mem_indexes, page_tensor=page_tensor, kv_buffer=kv_buffer, mode="write")

    # Build the expected page slot-by-slot: slot tid holds token selected[tid]
    # of every layer (single head, so head index is always 0).
    expected_page = torch.zeros_like(page_tensor)
    for tid, mem_index in enumerate(selected):
        for layer_index in range(layer_num):
            expected_page[tid, layer_index, 0, :] = kv_buffer[layer_index, mem_index, 0, :]
    assert torch.allclose(page_tensor[:token_num], expected_page[:token_num], atol=1e-3, rtol=1e-3)

    # Read back page -> kv
    kv_buffer_rt = torch.zeros_like(kv_buffer)
    mla_page_io(mem_indexes=mem_indexes, page_tensor=page_tensor, kv_buffer=kv_buffer_rt, mode="read")

    # Compare only the token positions that were round-tripped.
    ref = kv_buffer.index_select(dim=1, index=mem_indexes)
    out = kv_buffer_rt.index_select(dim=1, index=mem_indexes)
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

0 commit comments

Comments
 (0)