Skip to content

Commit 3475e4b

Browse files
authored
fix: fix cannot import name 'cuda' from 'cuda' in CUDA13 (#1764)
1 parent 84aaa3d commit 3475e4b

File tree

5 files changed

+47
-9
lines changed

5 files changed

+47
-9
lines changed

flashinfer/comm/mnnvl.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@
2525
import torch
2626

2727
try:
28-
from cuda import cuda
29-
except ImportError as e:
30-
raise ImportError(
31-
"Could not import the 'cuda' module. "
32-
"Please install cuda-python that matches your CUDA version."
33-
) from e
28+
# cuda-python >= 12.9 (has cuda.bindings.driver)
29+
from cuda.bindings import driver as cuda
30+
except ImportError:
31+
try:
32+
# cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
33+
# from cuda import cuda is not available in cuda-python >= 13.0
34+
from cuda import cuda
35+
except ImportError as e:
36+
raise ImportError(
37+
"Could not import the 'cuda' module. "
38+
"Please install cuda-python that matches your CUDA version."
39+
) from e
3440

3541
from ..cuda_utils import checkCudaErrors
3642
from .dlpack_utils import create_dlpack_capsule, pack_strided_memory

flashinfer/cuda_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
try:
2020
# Check if cuda.cudart module is available and import accordingly
2121
if has_cuda_cudart():
22-
# cuda-python <= 12.9 (has cuda.cudart)
22+
# cuda-python < 12.9 (has cuda.cudart)
2323
import cuda.bindings.driver as driver
2424
import cuda.bindings.runtime as runtime
2525
import cuda.cudart as cudart

flashinfer/cute_dsl/gemm_allreduce_two_shot.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,20 @@
22

33
import torch
44
import torch.distributed as dist
5-
from cuda import cuda
5+
6+
try:
7+
# cuda-python >= 12.9 (has cuda.bindings.driver)
8+
from cuda.bindings import driver as cuda
9+
except ImportError:
10+
try:
11+
# cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
12+
# from cuda import cuda is not available in cuda-python >= 13.0
13+
from cuda import cuda
14+
except ImportError as e:
15+
raise ImportError(
16+
"Could not import the 'cuda' module. "
17+
"Please install cuda-python that matches your CUDA version."
18+
) from e
619

720
import cutlass
821
import cutlass.cute as cute

flashinfer/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,12 @@ def has_cuda_cudart() -> bool:
453453
return importlib.util.find_spec("cuda.cudart") is not None
454454

455455

456+
def get_cuda_python_version() -> str:
457+
import cuda
458+
459+
return cuda.__version__
460+
461+
456462
def is_sm90a_supported(device: torch.device) -> bool:
457463
major, _ = get_compute_capability(device)
458464
return major == 9 and version_at_least(torch.version.cuda, "12.3")

tests/test_cute_dsl_gemm_allreduce_two_shot.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,20 @@
44
import socket
55
from typing import Any, Tuple, Type
66

7-
from cuda import cuda
7+
try:
8+
# cuda-python >= 12.9 (has cuda.bindings.driver)
9+
from cuda.bindings import driver as cuda
10+
except ImportError:
11+
try:
12+
# cuda-python <= 12.9 (no cuda.bindings.driver, use cuda as driver)
13+
# from cuda import cuda is not available in cuda-python >= 13.0
14+
from cuda import cuda
15+
except ImportError as e:
16+
raise ImportError(
17+
"Could not import the 'cuda' module. "
18+
"Please install cuda-python that matches your CUDA version."
19+
) from e
20+
821
import cutlass
922
import cutlass.cute as cute
1023
import cutlass.cute.testing as testing

0 commit comments

Comments
 (0)