Skip to content

Commit ecd6648

Browse files
authored
Small fix on an exception (#1775)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 08ed2b2 commit ecd6648

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_triton_cascade.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_merge_state(seq_len, num_heads, head_dim):
2121
assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
2222
assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
2323
except GPUArchitectureError as e:
24-
pytest.skip(e.msg)
24+
pytest.skip(str(e))
2525

2626

2727
@pytest.mark.parametrize("seq_len", [2048])
@@ -44,7 +44,7 @@ def test_merge_state_in_place(seq_len, num_heads, head_dim):
4444
assert torch.allclose(s, s_std, atol=1e-2)
4545

4646
except GPUArchitectureError as e:
47-
pytest.skip(e.msg)
47+
pytest.skip(str(e))
4848

4949

5050
@pytest.mark.parametrize("seq_len", [2048])
@@ -63,7 +63,7 @@ def test_merge_states(seq_len, num_states, num_heads, head_dim):
6363
assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
6464
assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
6565
except GPUArchitectureError as e:
66-
pytest.skip(e.msg)
66+
pytest.skip(str(e))
6767

6868

6969
@pytest.mark.parametrize("seq_len", [2048])
@@ -94,4 +94,4 @@ def test_variable_length_merge_states(seq_len, num_heads, head_dim):
9494
assert torch.allclose(v_merged[i], v_merged_std, atol=1e-2)
9595
assert torch.allclose(s_merged[i], s_merged_std, atol=1e-2)
9696
except GPUArchitectureError as e:
97-
pytest.skip(e.msg)
97+
pytest.skip(str(e))

0 commit comments

Comments
 (0)