11import dataclasses
22from functools import lru_cache
3- from typing import Optional
3+ from typing import Optional , Tuple
44
55import torch
66
@@ -21,26 +21,55 @@ def get_compute_capabilities() -> list[tuple[int, int]]:
2121
2222
2323@lru_cache (None )
24- def get_cuda_version_tuple () -> tuple [int , int ]:
25- if torch .version .cuda :
26- return tuple (map (int , torch .version .cuda .split ("." )[0 :2 ]))
27- elif torch .version .hip :
28- return tuple (map (int , torch .version .hip .split ("." )[0 :2 ]))
24+ def get_cuda_version_tuple () -> Optional [Tuple [int , int ]]:
25+ """Get CUDA/HIP version as a tuple of (major, minor)."""
26+ try :
27+ if torch .version .cuda :
28+ version_str = torch .version .cuda
29+ elif torch .version .hip :
30+ version_str = torch .version .hip
31+ else :
32+ return None
2933
30- return None
34+ parts = version_str .split ("." )
35+ if len (parts ) >= 2 :
36+ return tuple (map (int , parts [:2 ]))
37+ return None
38+ except (AttributeError , ValueError , IndexError ):
39+ return None
3140
3241
33- def get_cuda_version_string () -> str :
34- major , minor = get_cuda_version_tuple ()
42+ def get_cuda_version_string () -> Optional [str ]:
43+ """Get CUDA/HIP version as a string."""
44+ version_tuple = get_cuda_version_tuple ()
45+ if version_tuple is None :
46+ return None
47+ major , minor = version_tuple
3548 return f"{ major * 10 + minor } "
3649
3750
3851def get_cuda_specs () -> Optional [CUDASpecs ]:
52+ """Get CUDA/HIP specifications."""
3953 if not torch .cuda .is_available ():
4054 return None
4155
42- return CUDASpecs (
43- highest_compute_capability = (get_compute_capabilities ()[- 1 ]),
44- cuda_version_string = (get_cuda_version_string ()),
45- cuda_version_tuple = get_cuda_version_tuple (),
46- )
56+ try :
57+ compute_capabilities = get_compute_capabilities ()
58+ if not compute_capabilities :
59+ return None
60+
61+ version_tuple = get_cuda_version_tuple ()
62+ if version_tuple is None :
63+ return None
64+
65+ version_string = get_cuda_version_string ()
66+ if version_string is None :
67+ return None
68+
69+ return CUDASpecs (
70+ highest_compute_capability = compute_capabilities [- 1 ],
71+ cuda_version_string = version_string ,
72+ cuda_version_tuple = version_tuple ,
73+ )
74+ except Exception :
75+ return None
0 commit comments