
Commit b6cfc2c

Test refactoring and fixes (#1736)
## 📌 Description

Unit test fixes:

* Refactored test_mla_decode_kernel to run from pytest
* Added skip to test_mnnvl_custom_comm when world size is too large
* Added asserts to the cascade API when not using Hopper

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [V] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [V] Tests have been added or updated as needed.
- [V] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 7d7aa87 commit b6cfc2c

File tree

6 files changed: +126 additions, -78 deletions

flashinfer/triton/cascade.py

Lines changed: 6 additions & 4 deletions

@@ -10,6 +10,8 @@
 )
 from .utils import check_device, check_dim, check_input, check_shape
 
+EXPECT_HOPPER = 9
+
 
 def merge_state(
     v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
@@ -18,7 +20,7 @@ def merge_state(
     check_input(s_a)
     check_input(v_b)
     check_input(s_b)
-    check_device([v_a, s_a, v_b, s_b])
+    check_device([v_a, s_a, v_b, s_b], major=[EXPECT_HOPPER])
     check_dim(3, v_a)
     check_dim(2, s_a)
     check_dim(3, v_b)
@@ -55,7 +57,7 @@ def merge_state_in_place(
     check_input(s)
     check_input(v_other)
     check_input(s_other)
-    check_device([v, s, v_other, s_other])
+    check_device([v, s, v_other, s_other], major=[EXPECT_HOPPER])
     check_dim(3, v)
     check_dim(2, s)
     check_dim(3, v_other)
@@ -84,7 +86,7 @@ def merge_state_in_place(
 def merge_states(v: torch.Tensor, s: torch.Tensor):
     check_input(v)
     check_input(s)
-    check_device([v, s])
+    check_device([v, s], major=[EXPECT_HOPPER])
     check_dim(4, v)
     check_dim(3, s)
     assert v.size(0) == s.size(0)
@@ -121,7 +123,7 @@ def variable_length_merge_states(
 ):
     check_input(v)
     check_input(s)
-    check_device([v, s])
+    check_device([v, s], major=[EXPECT_HOPPER])
     check_dim(3, v)
     check_dim(2, s)
     assert v.size(0) == s.size(0)
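
For context, a minimal sketch of how the new guard surfaces to callers (tensor shapes and the non-Hopper scenario below are assumptions, not part of this change): when the device's compute-capability major is not 9, the Triton cascade kernels now raise `GPUArchitectureError` instead of running.

```python
# Illustrative sketch only: on a non-Hopper GPU (compute capability major != 9),
# merge_state now raises GPUArchitectureError via the check_device guard.
import torch

import flashinfer.triton
from flashinfer.utils import GPUArchitectureError

v_a = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda:0")
s_a = torch.randn(128, 32, dtype=torch.float32, device="cuda:0")
v_b = torch.randn(128, 32, 128, dtype=torch.float16, device="cuda:0")
s_b = torch.randn(128, 32, dtype=torch.float32, device="cuda:0")

try:
    v, s = flashinfer.triton.cascade.merge_state(v_a, s_a, v_b, s_b)
except GPUArchitectureError as e:
    print(f"Triton cascade kernel unavailable on this GPU: {e}")
```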

flashinfer/triton/utils.py

Lines changed: 14 additions & 1 deletion

@@ -1,6 +1,7 @@
 from typing import List
 
 import torch
+from flashinfer.utils import get_compute_capability, GPUArchitectureError
 
 
 def check_input(x: torch.Tensor):
@@ -20,9 +21,21 @@ def check_shape(a: torch.Tensor, b: torch.Tensor):
     )
 
 
-def check_device(tensors: List[torch.Tensor]):
+def check_device(
+    tensors: List[torch.Tensor], major: List[int] = None, minor: List[int] = None
+):
     device = tensors[0].device
     for t in tensors:
         assert t.device == device, (
             f"All tensors should be on the same device, but got {device} and {t.device}"
         )
+    if major is not None or minor is not None:
+        actual_major, actual_minor = get_compute_capability(device)
+        if major is not None and actual_major not in major:
+            raise GPUArchitectureError(
+                f"Device major should be in {major}, but got {actual_major}"
+            )
+        if minor is not None and actual_minor not in minor:
+            raise GPUArchitectureError(
+                f"Device minor should be in {minor}, but got {actual_minor}"
+            )
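
A small usage sketch of the extended helper (the capability values below are assumptions): the device-consistency assert behaves exactly as before, and the optional `major`/`minor` lists gate on the compute capability reported by `get_compute_capability`.

```python
# Illustrative only: check_device keeps its original device-consistency assert,
# and optionally restricts the allowed compute-capability major/minor versions.
import torch

from flashinfer.triton.utils import check_device
from flashinfer.utils import GPUArchitectureError

x = torch.zeros(4, device="cuda:0")
y = torch.zeros(4, device="cuda:0")

check_device([x, y])  # same-device check only, unchanged behavior

try:
    check_device([x, y], major=[9])  # additionally require a Hopper-class (9.x) GPU
except GPUArchitectureError as e:
    print(e)  # e.g. "Device major should be in [9], but got 8" on an older card
```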

flashinfer/utils.py

Lines changed: 12 additions & 0 deletions

@@ -51,6 +51,18 @@ class TensorLayout(Enum):
 log2e = 1.44269504088896340736
 
 
+class GPUArchitectureError(Exception):
+    def __init__(self, msg: str):
+        self.msg = msg
+        super().__init__(self.msg)
+
+    def __str__(self):
+        return self.msg
+
+    def __repr__(self):
+        return self.msg
+
+
 def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
     if x.ndim not in [4, 5]:
         raise ValueError("x must be 4D or 5D")
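
The new exception keeps its message on `.msg`, which is what the updated tests pass to `pytest.skip`; a quick sketch (the message text is just an example):

```python
from flashinfer.utils import GPUArchitectureError

try:
    raise GPUArchitectureError("Device major should be in [9], but got 8")
except GPUArchitectureError as e:
    # __str__ and __repr__ both return the stored message, so str(e) == e.msg
    assert str(e) == e.msg == repr(e)
    print(e.msg)
```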

tests/test_mla_decode_kernel.py

Lines changed: 29 additions & 8 deletions

@@ -1,10 +1,12 @@
 from typing import Optional, Tuple
+import pytest
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 import flashinfer
+from rope_reference import apply_rotary_emb, precompute_freqs_cis
 
 
 def wmape(target: torch.Tensor, preds: torch.Tensor):
@@ -13,9 +15,6 @@ def wmape(target: torch.Tensor, preds: torch.Tensor):
     return sum_abs_error / sum_scale
 
 
-from rope_reference import *
-
-
 class DeepseekV2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -247,6 +246,10 @@ def run_proof_of_concept(
         k_pe_cache: torch.Tensor,
         use_flashinfer_kernel: bool,
         convert_float16: bool,
+        bsz: int,
+        kv_len: int,
+        page_size: int,
+        dev_id: int,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         c_Q = torch.matmul(hidden_states, self.W_DQ)
         # c_Q ~ [bsz, q_lora_rank:1536]
@@ -392,18 +395,17 @@ def run_proof_of_concept(
         return output
 
 
-if __name__ == "__main__":
+@pytest.mark.parametrize("bsz", [6])
+@pytest.mark.parametrize("kv_len", [640])
+@pytest.mark.parametrize("page_size", [16])
+def test_mla_decode_kernel(bsz, kv_len, page_size):
     dev_id = 0
 
     torch.manual_seed(666)
     torch.set_grad_enabled(False)
 
     mla_vanilla = DeepseekV2AttentionVanilla().cuda(device=dev_id)
 
-    bsz = 6
-    kv_len = 640
-    page_size = 16
-
     hidden_states = torch.randn([bsz, 1, mla_vanilla.hidden_size]).to(dev_id)
     compressed_kv_normed_cache = torch.randn(
         [bsz, kv_len, mla_vanilla.kv_lora_rank]
@@ -421,20 +423,32 @@ def run_proof_of_concept(
         k_pe_cache,
         use_flashinfer_kernel=False,
         convert_float16=False,
+        bsz=bsz,
+        kv_len=kv_len,
+        page_size=page_size,
+        dev_id=dev_id,
     )
     output_mat_absorbed_use_torch_f16 = mla_mat_absorb.run_proof_of_concept(
         hidden_states.squeeze(1),
         compressed_kv_normed_cache,
         k_pe_cache,
        use_flashinfer_kernel=False,
         convert_float16=True,
+        bsz=bsz,
+        kv_len=kv_len,
+        page_size=page_size,
+        dev_id=dev_id,
     )
     output_mat_absorbed_use_flashinfer = mla_mat_absorb.run_proof_of_concept(
         hidden_states.squeeze(1),
         compressed_kv_normed_cache,
         k_pe_cache,
         use_flashinfer_kernel=True,
         convert_float16=True,
+        bsz=bsz,
+        kv_len=kv_len,
+        page_size=page_size,
+        dev_id=dev_id,
     )
 
     cos_use_torch_f32 = F.cosine_similarity(
@@ -489,3 +503,10 @@ def run_proof_of_concept(
         output_vanilla.reshape(-1), output_mat_absorbed_use_flashinfer.reshape(-1)
     )
     print(f"mse_use_flashinfer = {mse_use_flashinfer}")
+
+
+if __name__ == "__main__":
+    bsz = 6
+    kv_len = 640
+    page_size = 16
+    test_mla_decode_kernel(bsz, kv_len, page_size)
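
Since the script is now a parametrized pytest test, it can be collected by pytest directly or still run as a script through its `__main__` block. A hedged sketch of a programmatic invocation (the file path is assumed from this repository layout):

```python
# Illustrative: drive the refactored test through pytest instead of a bare script.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-v", "tests/test_mla_decode_kernel.py"]))
```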

tests/test_mnnvl_custom_comm.py

Lines changed: 2 additions & 1 deletion

@@ -169,9 +169,10 @@ def test_mnnvl_custom_communicator(world_size):
     dtype = torch.float16
     available_gpus = torch.cuda.device_count()
     if world_size > available_gpus:
-        raise ValueError(
+        pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
+
     print(f"Running test for world_size={world_size}")
 
     multi_process_parallel(

tests/test_triton_cascade.py

Lines changed: 63 additions & 64 deletions

@@ -3,96 +3,95 @@
 
 import flashinfer
 import flashinfer.triton
-from flashinfer.utils import get_compute_capability
+from flashinfer.utils import GPUArchitectureError
 
 
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_merge_state(seq_len, num_heads, head_dim):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    if compute_capability[0] != 9:
-        pytest.skip("These tests are only guaranteed to work on Hopper GPUs.")
+    try:
+        va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+        sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+        vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+        sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+        v_merged, s_merged = flashinfer.triton.cascade.merge_state(va, sa, vb, sb)
+        v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb)
 
-    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-    v_merged, s_merged = flashinfer.triton.cascade.merge_state(va, sa, vb, sb)
-    v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb)
-
-    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
-    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
+        assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
+        assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
+    except GPUArchitectureError as e:
+        pytest.skip(e.msg)
 
 
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_merge_state_in_place(seq_len, num_heads, head_dim):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    if compute_capability[0] != 9:
-        pytest.skip("These tests are only guaranteed to work on Hopper GPUs.")
+    try:
+        v = torch.randn(seq_len, num_heads, head_dim).half()
+        v_std = v.clone()
+        v, v_std = v.to("cuda:0"), v_std.to("cuda:0")
+        s = torch.randn(seq_len, num_heads, dtype=torch.float32)
+        s_std = s.clone()
+        s, s_std = s.to("cuda:0"), s_std.to("cuda:0")
+        v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+        s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+        flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other)
+        flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other)
 
-    v = torch.randn(seq_len, num_heads, head_dim).half()
-    v_std = v.clone()
-    v, v_std = v.to("cuda:0"), v_std.to("cuda:0")
-    s = torch.randn(seq_len, num_heads, dtype=torch.float32)
-    s_std = s.clone()
-    s, s_std = s.to("cuda:0"), s_std.to("cuda:0")
-    v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-    s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-    flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other)
-    flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other)
+        assert torch.allclose(v, v_std, atol=1e-2)
+        assert torch.allclose(s, s_std, atol=1e-2)
 
-    assert torch.allclose(v, v_std, atol=1e-2)
-    assert torch.allclose(s, s_std, atol=1e-2)
+    except GPUArchitectureError as e:
+        pytest.skip(e.msg)
 
 
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 @pytest.mark.parametrize("num_states", [100])
 def test_merge_states(seq_len, num_states, num_heads, head_dim):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    if compute_capability[0] != 9:
-        pytest.skip("These tests are only guaranteed to work on Hopper GPUs.")
-
-    v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
-    s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0")
-    v_merged_std, s_merged_std = flashinfer.merge_states(v, s)
-    v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s)
+    try:
+        v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
+        s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to(
+            "cuda:0"
+        )
+        v_merged_std, s_merged_std = flashinfer.merge_states(v, s)
+        v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s)
 
-    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
-    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
+        assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
+        assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
+    except GPUArchitectureError as e:
+        pytest.skip(e.msg)
 
 
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_variable_length_merge_states(seq_len, num_heads, head_dim):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
-    if compute_capability[0] != 9:
-        pytest.skip("These tests are only guaranteed to work on Hopper GPUs.")
-
-    max_index_sets = 512
-    lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,))
-    indptr = [0]
-    for i in range(seq_len):
-        indptr.append(indptr[-1] + lengths[i])
-    v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0")
-    s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0")
-    indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0")
-    v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states(
-        v, s, indptr
-    )
-    for i in range(seq_len):
-        sub_v = v[indptr[i] : indptr[i + 1]]
-        sub_s = s[indptr[i] : indptr[i + 1]]
-        sub_v = torch.unsqueeze(sub_v, 0)
-        sub_s = torch.unsqueeze(sub_s, 0)
-        v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s)
-        v_merged_std = torch.squeeze(v_merged_std, 0)
-        s_merged_std = torch.squeeze(s_merged_std, 0)
-        assert v_merged[i].shape == v_merged_std.shape
-        assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
-        assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
+    try:
+        max_index_sets = 512
+        lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,))
+        indptr = [0]
+        for i in range(seq_len):
+            indptr.append(indptr[-1] + lengths[i])
+        v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0")
+        s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0")
+        indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0")
+        v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states(
+            v, s, indptr
+        )
+        for i in range(seq_len):
+            sub_v = v[indptr[i] : indptr[i + 1]]
+            sub_s = s[indptr[i] : indptr[i + 1]]
+            sub_v = torch.unsqueeze(sub_v, 0)
+            sub_s = torch.unsqueeze(sub_s, 0)
+            v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s)
+            v_merged_std = torch.squeeze(v_merged_std, 0)
+            s_merged_std = torch.squeeze(s_merged_std, 0)
+            assert v_merged[i].shape == v_merged_std.shape
+            assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
+            assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
+    except GPUArchitectureError as e:
+        pytest.skip(e.msg)
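
The same pattern generalizes to other Triton-backed tests: let the library's own architecture check fire and convert it into a skip rather than probing the compute capability up front. A minimal sketch (the test name and tensor shapes are illustrative, not part of this change):

```python
# Illustrative pattern: rely on the cascade API's GPUArchitectureError instead of
# calling get_compute_capability directly, and turn it into a pytest skip.
import pytest
import torch

import flashinfer.triton
from flashinfer.utils import GPUArchitectureError


def test_merge_state_smoke():
    try:
        v = torch.randn(16, 4, 64, dtype=torch.float16, device="cuda:0")
        s = torch.randn(16, 4, dtype=torch.float32, device="cuda:0")
        flashinfer.triton.cascade.merge_state(v, s, v.clone(), s.clone())
    except GPUArchitectureError as e:
        pytest.skip(e.msg)
```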
