1717"""
1818
1919import ctypes
20+ import torch
2021
2122from .paths import determine_cuda_runtime_lib_path
2223from bitsandbytes .cextension import CUDASetup
@@ -29,8 +30,11 @@ def check_cuda_result(cuda, result_val):
2930 cuda .cuGetErrorString (result_val , ctypes .byref (error_str ))
3031 CUDASetup .get_instance ().add_log_entry (f"CUDA exception! Error code: { error_str .value .decode ()} " )
3132
33+
34+ # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
3235def get_cuda_version (cuda , cudart_path ):
33- # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
36+ if cuda is None : return None
37+
3438 try :
3539 cudart = ctypes .CDLL (cudart_path )
3640 except OSError :
@@ -72,7 +76,6 @@ def get_compute_capabilities(cuda):
7276 # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
7377 """
7478
75-
7679 nGpus = ctypes .c_int ()
7780 cc_major = ctypes .c_int ()
7881 cc_minor = ctypes .c_int ()
@@ -99,11 +102,11 @@ def get_compute_capability(cuda):
99102 capabilities are downwards compatible. If no GPUs are detected, it returns
100103 None.
101104 """
105+ if cuda is None : return None
106+
107+ # TODO: handle different compute capabilities; for now, take the max
102108 ccs = get_compute_capabilities (cuda )
103- if ccs :
104- # TODO: handle different compute capabilities; for now, take the max
105- return ccs [- 1 ]
106- return None
109+ if ccs : return ccs [- 1 ]
107110
108111
109112def evaluate_cuda_setup ():
@@ -113,28 +116,31 @@ def evaluate_cuda_setup():
113116 #print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
114117 #print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
115118 #print('='*80)
116- #if not torch.cuda.is_available():
117- #print('No GPU detected. Loading CPU library...')
118- #return binary_name
119-
120- binary_name = "libbitsandbytes_cpu.so"
119+ if not torch .cuda .is_available (): return 'libsbitsandbytes_cpu.so' , None , None , None , None
121120
122121 cuda_setup = CUDASetup .get_instance ()
123122 cudart_path = determine_cuda_runtime_lib_path ()
124- if cudart_path is None :
125- cuda_setup .add_log_entry ("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" , is_warning = True )
126- return binary_name
127-
128- cuda_setup .add_log_entry ((f"CUDA SETUP: CUDA runtime path found: { cudart_path } " ))
129123 cuda = get_cuda_lib_handle ()
130124 cc = get_compute_capability (cuda )
131- cuda_setup .add_log_entry (f"CUDA SETUP: Highest compute capability among GPUs detected: { cc } " )
132125 cuda_version_string = get_cuda_version (cuda , cudart_path )
133126
127+ failure = False
128+ if cudart_path is None :
129+ failure = True
130+ cuda_setup .add_log_entry ("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!" , is_warning = True )
131+ else :
132+ cuda_setup .add_log_entry ((f"CUDA SETUP: CUDA runtime path found: { cudart_path } " ))
134133
135134 if cc == '' or cc is None :
136- cuda_setup .add_log_entry ("WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." , is_warning = True )
137- return binary_name , cudart_path , cuda , cc , cuda_version_string
135+ failure = True
136+ cuda_setup .add_log_entry ("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library..." , is_warning = True )
137+ else :
138+ cuda_setup .add_log_entry (f"CUDA SETUP: Highest compute capability among GPUs detected: { cc } " )
139+
140+ if cuda is None :
141+ failure = True
142+ else :
143+ cuda_setup .add_log_entry (f'CUDA SETUP: Detected CUDA version { cuda_version_string } ' )
138144
139145 # 7.5 is the minimum CC vor cublaslt
140146 has_cublaslt = cc in ["7.5" , "8.0" , "8.6" ]
@@ -145,16 +151,13 @@ def evaluate_cuda_setup():
145151
146152 # we use ls -l instead of nvcc to determine the cuda version
147153 # since most installations will have the libcudart.so installed, but not the compiler
148- cuda_setup .add_log_entry (f'CUDA SETUP: Detected CUDA version { cuda_version_string } ' )
149154
150- def get_binary_name ():
155+ if failure :
156+ binary_name = "libbitsandbytes_cpu.so"
157+ elif has_cublaslt :
158+ binary_name = f"libbitsandbytes_cuda{ cuda_version_string } .so"
159+ else :
151160 "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
152- bin_base_name = "libbitsandbytes_cuda"
153- if has_cublaslt :
154- return f"{ bin_base_name } { cuda_version_string } .so"
155- else :
156- return f"{ bin_base_name } { cuda_version_string } _nocublaslt.so"
157-
158- binary_name = get_binary_name ()
161+ binary_name = f"libbitsandbytes_cuda{ cuda_version_string } _nocublaslt.so"
159162
160163 return binary_name , cudart_path , cuda , cc , cuda_version_string
0 commit comments