Skip to content

Commit feaedbb

Browse files
authored
fix: Improve CUDA version detection and error handling (#1599)
* fix: Improve CUDA version detection and error handling * lint fix * lint fix
1 parent b982796 commit feaedbb

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

bitsandbytes/cuda_specs.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]:
2121

2222

2323
@lru_cache(None)
24-
def get_cuda_version_tuple() -> tuple[int, int]:
25-
if torch.version.cuda:
26-
return tuple(map(int, torch.version.cuda.split(".")[0:2]))
27-
elif torch.version.hip:
28-
return tuple(map(int, torch.version.hip.split(".")[0:2]))
24+
def get_cuda_version_tuple() -> Optional[tuple[int, int]]:
25+
"""Get CUDA/HIP version as a tuple of (major, minor)."""
26+
try:
27+
if torch.version.cuda:
28+
version_str = torch.version.cuda
29+
elif torch.version.hip:
30+
version_str = torch.version.hip
31+
else:
32+
return None
2933

30-
return None
34+
parts = version_str.split(".")
35+
if len(parts) >= 2:
36+
return tuple(map(int, parts[:2]))
37+
return None
38+
except (AttributeError, ValueError, IndexError):
39+
return None
3140

3241

33-
def get_cuda_version_string() -> str:
34-
major, minor = get_cuda_version_tuple()
42+
def get_cuda_version_string() -> Optional[str]:
43+
"""Get CUDA/HIP version as a string."""
44+
version_tuple = get_cuda_version_tuple()
45+
if version_tuple is None:
46+
return None
47+
major, minor = version_tuple
3548
return f"{major * 10 + minor}"
3649

3750

3851
def get_cuda_specs() -> Optional[CUDASpecs]:
52+
"""Get CUDA/HIP specifications."""
3953
if not torch.cuda.is_available():
4054
return None
4155

42-
return CUDASpecs(
43-
highest_compute_capability=(get_compute_capabilities()[-1]),
44-
cuda_version_string=(get_cuda_version_string()),
45-
cuda_version_tuple=get_cuda_version_tuple(),
46-
)
56+
try:
57+
compute_capabilities = get_compute_capabilities()
58+
if not compute_capabilities:
59+
return None
60+
61+
version_tuple = get_cuda_version_tuple()
62+
if version_tuple is None:
63+
return None
64+
65+
version_string = get_cuda_version_string()
66+
if version_string is None:
67+
return None
68+
69+
return CUDASpecs(
70+
highest_compute_capability=compute_capabilities[-1],
71+
cuda_version_string=version_string,
72+
cuda_version_tuple=version_tuple,
73+
)
74+
except Exception:
75+
return None

0 commit comments

Comments
 (0)