Commit 91e6140

xfail the cute dsl tests for l=1 (#1868)
## 📌 Description

With the latest version of nvidia-cutlass-dsl, `mark_layout_dynamic` may raise errors like

```
>       self._dltensor_wrapper.mark_layout_dynamic(leading_dim)
E       RuntimeError: Expected strides[leading_dim] == 1, but got 7340032.
```

when `cutlass.torch.cute_tensor_like` is called and `l = 1` in the GEMM problem size. This issue has been reported as [cutlass#2673](NVIDIA/cutlass#2673).

This PR therefore marks the CuTe DSL block-scaled GEMM tests where `l = 1` as expected failures (xfail), because of the nvidia-cutlass-dsl issue.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
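The imperative xfail described above can be sketched in isolation. This is a minimal illustration, not the FlashInfer test itself: the test name `test_gemm_shape`, the `(l, m)` values, and the placeholder assertion are hypothetical. Only `pytest.xfail` and the reason string come from the change.

```python
import pytest


# Hypothetical (l, m) pairs; the real test parametrizes many more arguments.
@pytest.mark.parametrize("lm", [(1, 128), (2, 128)])
def test_gemm_shape(lm):
    l, m = lm
    if l == 1:
        # Imperative xfail: raises immediately, so nothing below runs for l == 1
        # and pytest reports the case as xfailed rather than failed.
        pytest.xfail("nvidia-cutlass-dsl has issue when l=1")
    # Placeholder for the real GEMM correctness check.
    assert l * m > 0
```

Unlike `pytest.skip`, an xfail records that the case is expected to fail. Because the imperative call stops the test at that point, the `l = 1` cases will not start passing on their own once the upstream issue is fixed; the guard has to be removed by hand.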
1 parent 83af9a8 commit 91e6140

1 file changed: +5 -4 lines changed

tests/gemm/test_cute_dsl_blockscaled_gemm.py

Lines changed: 5 additions & 4 deletions
@@ -19,6 +19,7 @@
 )
 from flashinfer.cute_dsl.utils import (
     get_cutlass_dtype,
+    get_num_sm,
     is_cute_dsl_available,
 )
 
@@ -56,7 +57,6 @@
 @pytest.mark.parametrize("alpha_dtype", ["float32"])
 @pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
 @pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
-@pytest.mark.parametrize("sm_count", [132, None])
 @pytest.mark.parametrize("tolerance", [1e-01])
 @pytest.mark.parametrize("iterations", [3])
 @pytest.mark.parametrize("enable_dst_signals", [False, True])
@@ -74,7 +74,6 @@ def test_blockscaled_gemm_python_interface(
     alpha_dtype: cutlass.dtype,
     mma_tiler_mn: Tuple[int, int],
     cluster_shape_mn: Tuple[int, int],
-    sm_count: int,
     tolerance: float,
     iterations: int,
     enable_dst_signals: int,
@@ -85,11 +84,13 @@ def test_blockscaled_gemm_python_interface(
 
     if not (major == 10 and minor == 0):
         pytest.skip("Cute-dsl backend is only supported on SM100.")
-    if enable_dst_signals and (sm_count is None):
-        pytest.skip("dst_signals require sm_count")
 
     l, m = lm
     k, n = kn
+    if l == 1:
+        pytest.xfail("nvidia-cutlass-dsl has issue when l=1")
+
+    sm_count = get_num_sm(device) if enable_dst_signals else None
 
     print(f"device: {device}")
 
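The diff also replaces the `sm_count` parametrization with a value derived from `enable_dst_signals`. A minimal sketch of that derivation follows. The helper `resolve_sm_count` and the stand-in query `fake_query` are hypothetical names for illustration; the real test calls `get_num_sm(device)`, whose behavior is not shown in this diff.

```python
from typing import Callable, Optional


def resolve_sm_count(
    enable_dst_signals: bool, query_num_sm: Callable[[], int]
) -> Optional[int]:
    """Return the device SM count only when dst signals are enabled."""
    return query_num_sm() if enable_dst_signals else None


def fake_query() -> int:
    # Hypothetical stand-in for a device query; 132 is the value the
    # removed parametrization hard-coded.
    return 132


print(resolve_sm_count(True, fake_query))
print(resolve_sm_count(False, fake_query))
```

Deriving the value this way removes the `enable_dst_signals=True, sm_count=None` combination that the old code had to skip, so each remaining parameter set runs with a valid pairing.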
