
Commit 40df947

tests: upgrade cutlass, fix import and skip non-SM100 cutedsl two shot allreduce (#1812)

## 📌 Description

Upgrade the cutlass-dsl Python package to 4.2.1 to support `distributed_helpers`. Fix a module-level import problem and skip the cutedsl two-shot allreduce unit tests on non-SM100 GPUs.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

1 parent ba2b4aa · commit 40df947

File tree

3 files changed: +7 −15 lines changed


flashinfer/cute_dsl/gemm_allreduce_two_shot.py

Lines changed: 1 addition & 15 deletions

```diff
@@ -3,20 +3,6 @@
 import torch
 import torch.distributed as dist
 
-try:
-    # cuda-python >= 12.9 (has cuda.bindings.driver)
-    from cuda.bindings import driver as cuda
-except ImportError:
-    try:
-        # cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
-        # from cuda import cuda is not available in cuda-python >= 13.0
-        from cuda import cuda
-    except ImportError as e:
-        raise ImportError(
-            "Could not import the 'cuda' module. "
-            "Please install cuda-python that matches your CUDA version."
-        ) from e
-
 import cutlass
 import cutlass.cute as cute
 import cutlass.utils as utils
@@ -380,7 +366,7 @@ def __call__(
         b: cute.Tensor,
         c: cute.Tensor,
         max_active_clusters: cutlass.Constexpr,
-        stream: cuda.CUstream,
+        stream,
         epilogue_op: cutlass.Constexpr = lambda x: x,
         c_mc: cute.Tensor = None,
         barrier_flag: cute.Tensor = None,
```
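The removed block is the module-level cuda-python version shim that made `import flashinfer.cute_dsl.gemm_allreduce_two_shot` fail outright whenever cuda-python was missing or laid out differently; dropping the `cuda.CUstream` annotation on `stream` removes the last import-time need for the driver bindings. If such a shim were ever wanted again, a minimal sketch (not part of this commit) would defer it into a helper so the failure happens at call time rather than at import:

```python
def _load_cuda_driver():
    """Hypothetical helper: resolve the cuda-python driver module lazily.

    Merely importing the enclosing module can then never fail; only code
    that actually needs the driver pays for (or surfaces) a bad install.
    """
    try:
        # cuda-python >= 12.9 exposes the driver under cuda.bindings
        from cuda.bindings import driver as cuda
    except ImportError:
        try:
            # older layout; `from cuda import cuda` was removed in >= 13.0
            from cuda import cuda
        except ImportError as e:
            raise ImportError(
                "Could not import the 'cuda' module. "
                "Please install cuda-python that matches your CUDA version."
            ) from e
    return cuda
```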

setup.py

File mode changed: 100644 → 100755

Lines changed: 1 addition & 0 deletions

```diff
@@ -94,6 +94,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:
     "apache-tvm-ffi==0.1.0b11",
     "packaging>=24.2",
     "nvidia-cudnn-frontend>=1.13.0",
+    "nvidia-cutlass-dsl>=4.2.1",
 ]
 generate_build_meta({})
 
```
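With the floor raised to 4.2.1 (the release the description says adds `distributed_helpers`), one illustrative way to confirm an existing environment already satisfies the new minimum; this check is not part of the commit, and it leans on the `packaging` dependency visible in the diff above:

```python
# Illustrative version check for the raised nvidia-cutlass-dsl requirement.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("nvidia-cutlass-dsl"))
assert installed >= Version("4.2.1"), f"found {installed}, need >= 4.2.1"
```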

tests/unlisted/test_cute_dsl_gemm_allreduce_two_shot.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -30,6 +30,7 @@
 import torch.distributed._symmetric_memory as symm_mem
 
 from flashinfer.cute_dsl.gemm_allreduce_two_shot import PersistentDenseGemmKernel
+from flashinfer.utils import get_compute_capability
 
 
 logger = logging.getLogger(__name__)
@@ -482,6 +483,10 @@ def test_cute_dsl_gemm_allreduce_two_shot(world_size):
         pytest.skip(
             f"world_size {world_size} is greater than available_gpus {available_gpus}"
         )
+
+    if get_compute_capability(torch.device("cuda")) != (10, 0):
+        pytest.skip("cute_dsl_gemm_allreduce_two_shot requires SM100")
+
     print(f"Running test for world_size={world_size}")
     multi_process_parallel(
         world_size,
```
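The new guard compares the device's compute capability against `(10, 0)`, i.e. SM100 (Blackwell). A minimal sketch of the behavior `get_compute_capability` is assumed to have (the real helper lives in `flashinfer.utils`), built on PyTorch's standard device query:

```python
import torch

def get_compute_capability(device: torch.device) -> tuple[int, int]:
    # Assumed behavior: return the (major, minor) CUDA compute capability
    # of the given device, so an SM100 GPU reports (10, 0) and any other
    # architecture trips the pytest.skip in the test above.
    return torch.cuda.get_device_capability(device)
```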
