Commit ea56964: fix error message (#1789)
<!-- .github/pull_request_template.md -->

## πŸ“Œ Description

Fix the "AttributeError: 'GPUArchitectureError' object has no attribute 'msg'" error.

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 4f884bf commit ea56964
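
For context, a minimal sketch of the failure mode, assuming `GPUArchitectureError` is an ordinary `Exception` subclass that does not define a `msg` attribute (the class definition itself is not part of this commit):

```python
# Illustrative only: the real flashinfer.utils.GPUArchitectureError definition
# is not part of this commit; it is assumed here to be a plain Exception
# subclass with no custom `msg` attribute.
class GPUArchitectureError(RuntimeError):
    pass


err = GPUArchitectureError("unsupported GPU architecture for this kernel")

print(str(err))  # -> unsupported GPU architecture for this kernel

try:
    print(err.msg)
except AttributeError as exc:
    # -> 'GPUArchitectureError' object has no attribute 'msg'
    print(exc)
```

`str(e)` always yields the message passed to the exception constructor, so the skip reason survives regardless of how the exception class is defined.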

File tree: 2 files changed (+54, -65 lines)
tests/conftest.py (1 addition, 1 deletion)

@@ -152,7 +152,7 @@ def wrapper(*args, **kwargs):
         try:
             return func(*args, **kwargs)
         except flashinfer.utils.GPUArchitectureError as e:
-            pytest.skip(e.msg)
+            pytest.skip(str(e))
 
     return wrapper
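
For reference, a minimal sketch of the full helper this hunk touches, assuming the usual `functools.wraps` decorator pattern; only the `try`/`except` body is confirmed by the diff above, the surrounding scaffolding is an assumption:

```python
# Sketch of the skip_on_gpu_arch_error helper in tests/conftest.py; only the
# try/except body appears in the hunk above, the rest is assumed boilerplate.
import functools

import pytest

import flashinfer


def skip_on_gpu_arch_error(func):
    """Skip the wrapped test when the current GPU architecture is unsupported."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except flashinfer.utils.GPUArchitectureError as e:
            # str(e) carries the constructor message; the class defines no .msg.
            pytest.skip(str(e))

    return wrapper
```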

tests/test_triton_cascade.py (53 additions, 64 deletions)

@@ -3,95 +3,84 @@
 
 import flashinfer
 import flashinfer.triton
-from flashinfer.utils import GPUArchitectureError
+from conftest import skip_on_gpu_arch_error
 
 
+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_merge_state(seq_len, num_heads, head_dim):
-    try:
-        va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-        sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-        vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-        sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-        v_merged, s_merged = flashinfer.triton.cascade.merge_state(va, sa, vb, sb)
-        v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb)
+    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    v_merged, s_merged = flashinfer.triton.cascade.merge_state(va, sa, vb, sb)
+    v_merged_std, s_merged_std = flashinfer.merge_state(va, sa, vb, sb)
 
-        assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
-        assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
-    except GPUArchitectureError as e:
-        pytest.skip(str(e))
+    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
+    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
 
 
+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_merge_state_in_place(seq_len, num_heads, head_dim):
-    try:
-        v = torch.randn(seq_len, num_heads, head_dim).half()
-        v_std = v.clone()
-        v, v_std = v.to("cuda:0"), v_std.to("cuda:0")
-        s = torch.randn(seq_len, num_heads, dtype=torch.float32)
-        s_std = s.clone()
-        s, s_std = s.to("cuda:0"), s_std.to("cuda:0")
-        v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
-        s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
-        flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other)
-        flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other)
+    v = torch.randn(seq_len, num_heads, head_dim).half()
+    v_std = v.clone()
+    v, v_std = v.to("cuda:0"), v_std.to("cuda:0")
+    s = torch.randn(seq_len, num_heads, dtype=torch.float32)
+    s_std = s.clone()
+    s, s_std = s.to("cuda:0"), s_std.to("cuda:0")
+    v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    flashinfer.merge_state_in_place(v_std, s_std, v_other, s_other)
+    flashinfer.triton.cascade.merge_state_in_place(v, s, v_other, s_other)
 
-        assert torch.allclose(v, v_std, atol=1e-2)
-        assert torch.allclose(s, s_std, atol=1e-2)
-
-    except GPUArchitectureError as e:
-        pytest.skip(str(e))
+    assert torch.allclose(v, v_std, atol=1e-2)
+    assert torch.allclose(s, s_std, atol=1e-2)
 
 
+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 @pytest.mark.parametrize("num_states", [100])
 def test_merge_states(seq_len, num_states, num_heads, head_dim):
-    try:
-        v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
-        s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to(
-            "cuda:0"
-        )
-        v_merged_std, s_merged_std = flashinfer.merge_states(v, s)
-        v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s)
+    v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
+    s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0")
+    v_merged_std, s_merged_std = flashinfer.merge_states(v, s)
+    v_merged, s_merged = flashinfer.triton.cascade.merge_states(v, s)
 
-        assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
-        assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
-    except GPUArchitectureError as e:
-        pytest.skip(str(e))
+    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
+    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
 
 
+@skip_on_gpu_arch_error
 @pytest.mark.parametrize("seq_len", [2048])
 @pytest.mark.parametrize("num_heads", [32])
 @pytest.mark.parametrize("head_dim", [128])
 def test_variable_length_merge_states(seq_len, num_heads, head_dim):
-    try:
-        max_index_sets = 512
-        lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,))
-        indptr = [0]
-        for i in range(seq_len):
-            indptr.append(indptr[-1] + lengths[i])
-        v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0")
-        s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0")
-        indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0")
-        v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states(
-            v, s, indptr
-        )
-        for i in range(seq_len):
-            sub_v = v[indptr[i] : indptr[i + 1]]
-            sub_s = s[indptr[i] : indptr[i + 1]]
-            sub_v = torch.unsqueeze(sub_v, 0)
-            sub_s = torch.unsqueeze(sub_s, 0)
-            v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s)
-            v_merged_std = torch.squeeze(v_merged_std, 0)
-            s_merged_std = torch.squeeze(s_merged_std, 0)
-            assert v_merged[i].shape == v_merged_std.shape
-            assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
-            assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
-    except GPUArchitectureError as e:
-        pytest.skip(str(e))
+    max_index_sets = 512
+    lengths = torch.randint(low=1, high=max_index_sets, size=(seq_len,))
+    indptr = [0]
+    for i in range(seq_len):
+        indptr.append(indptr[-1] + lengths[i])
+    v = torch.randn(indptr[-1], num_heads, head_dim).half().to("cuda:0")
+    s = torch.randn(indptr[-1], num_heads, dtype=torch.float32).to("cuda:0")
+    indptr = torch.tensor(indptr, dtype=torch.int32).to("cuda:0")
+    v_merged, s_merged = flashinfer.triton.cascade.variable_length_merge_states(
+        v, s, indptr
+    )
+    for i in range(seq_len):
+        sub_v = v[indptr[i] : indptr[i + 1]]
+        sub_s = s[indptr[i] : indptr[i + 1]]
+        sub_v = torch.unsqueeze(sub_v, 0)
+        sub_s = torch.unsqueeze(sub_s, 0)
+        v_merged_std, s_merged_std = flashinfer.merge_states(sub_v, sub_s)
+        v_merged_std = torch.squeeze(v_merged_std, 0)
+        s_merged_std = torch.squeeze(s_merged_std, 0)
+        assert v_merged[i].shape == v_merged_std.shape
+        assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
+        assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
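
As a usage note, any other architecture-gated test can now reuse the decorator instead of wrapping its body in `try`/`except`. A hypothetical example (not part of this commit), following the import style and tensor shapes of the file above:

```python
# Hypothetical test, not part of this commit; shapes follow the existing
# merge_state tests: (seq_len, num_heads, head_dim) for values and
# (seq_len, num_heads) for the per-head scores.
import pytest
import torch

import flashinfer
from conftest import skip_on_gpu_arch_error


@skip_on_gpu_arch_error
@pytest.mark.parametrize("seq_len", [16])
def test_merge_state_smoke(seq_len):
    va = torch.randn(seq_len, 4, 64).half().to("cuda:0")
    sa = torch.randn(seq_len, 4, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, 4, 64).half().to("cuda:0")
    sb = torch.randn(seq_len, 4, dtype=torch.float32).to("cuda:0")
    v_merged, s_merged = flashinfer.merge_state(va, sa, vb, sb)
    assert v_merged.shape == va.shape
    assert s_merged.shape == sa.shape
```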
