File tree Expand file tree Collapse file tree 5 files changed +47
-9
lines changed Expand file tree Collapse file tree 5 files changed +47
-9
lines changed Original file line number Diff line number Diff line change 25
25
import torch
26
26
27
27
try :
28
- from cuda import cuda
29
- except ImportError as e :
30
- raise ImportError (
31
- "Could not import the 'cuda' module. "
32
- "Please install cuda-python that matches your CUDA version."
33
- ) from e
28
+ # cuda-python >= 12.9 (has cuda.bindings.driver)
29
+ from cuda .bindings import driver as cuda
30
+ except ImportError :
31
+ try :
32
+ # cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
33
+ # from cuda import cuda is not available in cuda-python >= 13.0
34
+ from cuda import cuda
35
+ except ImportError as e :
36
+ raise ImportError (
37
+ "Could not import the 'cuda' module. "
38
+ "Please install cuda-python that matches your CUDA version."
39
+ ) from e
34
40
35
41
from ..cuda_utils import checkCudaErrors
36
42
from .dlpack_utils import create_dlpack_capsule , pack_strided_memory
Original file line number Diff line number Diff line change 19
19
try :
20
20
# Check if cuda.cudart module is available and import accordingly
21
21
if has_cuda_cudart ():
22
- # cuda-python <= 12.9 (has cuda.cudart)
22
+ # cuda-python < 12.9 (has cuda.cudart)
23
23
import cuda .bindings .driver as driver
24
24
import cuda .bindings .runtime as runtime
25
25
import cuda .cudart as cudart
Original file line number Diff line number Diff line change 2
2
3
3
import torch
4
4
import torch .distributed as dist
5
- from cuda import cuda
5
+
6
+ try :
7
+ # cuda-python >= 12.9 (has cuda.bindings.driver)
8
+ from cuda .bindings import driver as cuda
9
+ except ImportError :
10
+ try :
11
+ # cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
12
+ # from cuda import cuda is not available in cuda-python >= 13.0
13
+ from cuda import cuda
14
+ except ImportError as e :
15
+ raise ImportError (
16
+ "Could not import the 'cuda' module. "
17
+ "Please install cuda-python that matches your CUDA version."
18
+ ) from e
6
19
7
20
import cutlass
8
21
import cutlass .cute as cute
Original file line number Diff line number Diff line change @@ -453,6 +453,12 @@ def has_cuda_cudart() -> bool:
453
453
return importlib .util .find_spec ("cuda.cudart" ) is not None
454
454
455
455
456
+ def get_cuda_python_version () -> str :
457
+ import cuda
458
+
459
+ return cuda .__version__
460
+
461
+
456
462
def is_sm90a_supported (device : torch .device ) -> bool :
457
463
major , _ = get_compute_capability (device )
458
464
return major == 9 and version_at_least (torch .version .cuda , "12.3" )
Original file line number Diff line number Diff line change 4
4
import socket
5
5
from typing import Any , Tuple , Type
6
6
7
- from cuda import cuda
7
+ try :
8
+ # cuda-python >= 12.9 (has cuda.bindings.driver)
9
+ from cuda .bindings import driver as cuda
10
+ except ImportError :
11
+ try :
12
+ # cuda-python <= 12.9 (no cuda.bindings.driver, use cuda as driver)
13
+ # from cuda import cuda is not available in cuda-python >= 13.0
14
+ from cuda import cuda
15
+ except ImportError as e :
16
+ raise ImportError (
17
+ "Could not import the 'cuda' module. "
18
+ "Please install cuda-python that matches your CUDA version."
19
+ ) from e
20
+
8
21
import cutlass
9
22
import cutlass .cute as cute
10
23
import cutlass .cute .testing as testing
You can’t perform that action at this time.
0 commit comments