Skip to content

Commit 1ab6758

Browse files
committed
Changed CUDA setup to use PyTorch default; added a weak test.
1 parent ac155f7 commit 1ab6758

File tree

3 files changed

+69
-110
lines changed

3 files changed

+69
-110
lines changed

bitsandbytes/cextension.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@
3838

3939

4040
# print the setup details after checking for errors so we do not print twice
41-
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
42-
setup.print_log_stack()
41+
#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
42+
#setup.print_log_stack()

bitsandbytes/cuda_setup/main.py

Lines changed: 43 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,14 @@ def generate_instructions(self):
4747
if getattr(self, 'error', False): return
4848
print(self.error)
4949
self.error = True
50-
if self.cuda is None:
51-
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
50+
if not self.cuda_available:
51+
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.')
5252
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
5353
self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
5454
self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
5555
self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
5656
self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
57+
self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)')
5758
return
5859

5960
if self.cudart_path is None:
@@ -98,20 +99,26 @@ def initialize(self):
9899
self.initialized = False
99100
self.error = False
100101

102+
def manual_override(self):
103+
if torch.cuda.is_available():
104+
if 'CUDA_HOME' in os.environ and 'CUDA_VERSION' in os.environ:
105+
if len(os.environ['CUDA_HOME']) > 0 and len(os.environ['CUDA_VERSION']) > 0:
106+
self.binary_name = self.binary_name[:-6] + f'{os.environ["CUDA_VERSION"]}.so'
107+
101108
def run_cuda_setup(self):
102109
self.initialized = True
103110
self.cuda_setup_log = []
104111

105-
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
112+
binary_name, cudart_path, cc, cuda_version_string = evaluate_cuda_setup()
106113
self.cudart_path = cudart_path
107-
self.cuda = cuda
114+
self.cuda_available = torch.cuda.is_available()
108115
self.cc = cc
109116
self.cuda_version_string = cuda_version_string
117+
self.binary_name = binary_name
118+
self.manual_override()
110119

111120
package_dir = Path(__file__).parent.parent
112-
binary_path = package_dir / binary_name
113-
114-
print('bin', binary_path)
121+
binary_path = package_dir / self.binary_name
115122

116123
try:
117124
if not binary_path.exists():
@@ -123,10 +130,12 @@ def run_cuda_setup(self):
123130
self.add_log_entry('')
124131
self.add_log_entry('='*48 + 'ERROR' + '='*37)
125132
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
126-
self.add_log_entry('1. CUDA driver not installed')
127-
self.add_log_entry('2. CUDA not installed')
128-
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
129-
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
133+
self.add_log_entry('1. You need to manually override the PyTorch CUDA version. Please see: '
134+
'"https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md')
135+
self.add_log_entry('2. CUDA driver not installed')
136+
self.add_log_entry('3. CUDA not installed')
137+
self.add_log_entry('4. You have multiple conflicting CUDA libraries')
138+
self.add_log_entry('5. Required library not pre-compiled for this bitsandbytes release!')
130139
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
131140
self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
132141
self.add_log_entry('='*80)
@@ -218,11 +227,13 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
218227
if len(results_paths) > 1:
219228
warning_msg = (
220229
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
221-
"We'll flip a coin and try one of these, in order to fail forward.\n"
222-
"Either way, this might cause trouble in the future:\n"
223-
"If you get `CUDA error: invalid device function` errors, the above "
224-
"might be the cause and the solution is to make sure only one "
225-
f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
230+
"We select the PyTorch default libcudart.so, which is {torch.version.cuda},"
231+
"but this might missmatch with the CUDA version that is needed for bitsandbytes."
232+
"To override this behavior set the CUDA_HOME environmental variable"
233+
"For example, if you want to use the CUDA version wht the path"
234+
"/usr/local/cuda-11.2/lib/libcudart.so as the default,"
235+
"then add the following to your .bashrc:"
236+
"export CUDA_HOME=/usr/local/cuda-11.2")
226237
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
227238

228239

@@ -240,14 +251,15 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
240251
"""
241252
candidate_env_vars = get_potentially_lib_path_containing_env_vars()
242253

254+
cuda_runtime_libs = set()
243255
if "CONDA_PREFIX" in candidate_env_vars:
244256
conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
245257

246258
conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
247259
warn_in_case_of_duplicates(conda_cuda_libs)
248260

249261
if conda_cuda_libs:
250-
return next(iter(conda_cuda_libs))
262+
cuda_runtime_libs.update(conda_cuda_libs)
251263

252264
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
253265
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
@@ -256,7 +268,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
256268
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
257269

258270
if lib_ld_cuda_libs:
259-
return next(iter(lib_ld_cuda_libs))
271+
cuda_runtime_libs.update(lib_ld_cuda_libs)
260272
warn_in_case_of_duplicates(lib_ld_cuda_libs)
261273

262274
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
@@ -277,33 +289,21 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
277289

278290
warn_in_case_of_duplicates(cuda_runtime_libs)
279291

292+
print(cuda_runtime_libs, flush=True)
293+
280294
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
281295

282296

283297
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
284-
def get_cuda_version(cuda, cudart_path):
285-
if cuda is None: return None
286-
298+
def get_cuda_version():
287299
major, minor = map(int, torch.version.cuda.split("."))
288300

289301
if major < 11:
290302
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
291303

292304
return f'{major}{minor}'
293305

294-
295-
def get_cuda_lib_handle():
296-
# 1. find libcuda.so library (GPU driver) (/usr/lib)
297-
try:
298-
cuda = ct.CDLL("libcuda.so")
299-
except OSError:
300-
CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
301-
return None
302-
303-
return cuda
304-
305-
306-
def get_compute_capabilities(cuda):
306+
def get_compute_capabilities():
307307
ccs = []
308308
for i in range(torch.cuda.device_count()):
309309
cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i))
@@ -312,20 +312,6 @@ def get_compute_capabilities(cuda):
312312
return ccs
313313

314314

315-
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
316-
def get_compute_capability(cuda):
317-
"""
318-
Extracts the highest compute capbility from all available GPUs, as compute
319-
capabilities are downwards compatible. If no GPUs are detected, it returns
320-
None.
321-
"""
322-
if cuda is None: return None
323-
324-
# TODO: handle different compute capabilities; for now, take the max
325-
ccs = get_compute_capabilities(cuda)
326-
if ccs: return ccs[-1]
327-
328-
329315
def evaluate_cuda_setup():
330316
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
331317
print('')
@@ -337,27 +323,15 @@ def evaluate_cuda_setup():
337323

338324
cuda_setup = CUDASetup.get_instance()
339325
cudart_path = determine_cuda_runtime_lib_path()
340-
cuda = get_cuda_lib_handle()
341-
cc = get_compute_capability(cuda)
342-
cuda_version_string = get_cuda_version(cuda, cudart_path)
343-
344-
failure = False
345-
if cudart_path is None:
346-
failure = True
347-
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
348-
else:
349-
cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
326+
ccs = get_compute_capabilities()
327+
ccs.sort()
328+
cc = ccs[-1] # we take the highest capability
329+
cuda_version_string = get_cuda_version()
350330

351-
if cc == '' or cc is None:
352-
failure = True
353-
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
354-
else:
355-
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
331+
cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.")
332+
cuda_setup.add_log_entry(f"CUDA SETUP: To manually override the PyTorch CUDA version please see:"
333+
"https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md")
356334

357-
if cuda is None:
358-
failure = True
359-
else:
360-
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
361335

362336
# 7.5 is the minimum CC vor cublaslt
363337
has_cublaslt = is_cublasLt_compatible(cc)
@@ -369,12 +343,10 @@ def evaluate_cuda_setup():
369343
# we use ls -l instead of nvcc to determine the cuda version
370344
# since most installations will have the libcudart.so installed, but not the compiler
371345

372-
if failure:
373-
binary_name = "libbitsandbytes_cpu.so"
374-
elif has_cublaslt:
346+
if has_cublaslt:
375347
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
376348
else:
377349
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
378350
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
379351

380-
return binary_name, cudart_path, cuda, cc, cuda_version_string
352+
return binary_name, cudart_path, cc, cuda_version_string

tests/test_cuda_setup_evaluator.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,27 @@
11
import os
2-
from typing import List, NamedTuple
3-
42
import pytest
3+
import torch
4+
from pathlib import Path
5+
6+
# hardcoded test. Not good, but a sanity check for now
7+
def test_manual_override():
8+
manual_cuda_path = str(Path('/mmfs1/home/dettmers/data/local/cuda-12.2'))
9+
10+
pytorch_version = torch.version.cuda.replace('.', '')
11+
12+
assert pytorch_version != 122
13+
14+
os.environ['CUDA_HOME']='{manual_cuda_path}'
15+
os.environ['CUDA_VERSION']='122'
16+
assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH']
17+
import bitsandbytes as bnb
18+
loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name
19+
assert loaded_lib == 'libbitsandbytes_cuda122.so'
20+
21+
22+
23+
24+
25+
26+
527

6-
import bitsandbytes as bnb
7-
from bitsandbytes.cuda_setup.main import (
8-
determine_cuda_runtime_lib_path,
9-
evaluate_cuda_setup,
10-
extract_candidate_paths,
11-
)
12-
13-
14-
def test_cuda_full_system():
15-
## this only tests the cuda version and not compute capability
16-
17-
# if CONDA_PREFIX exists, it has priority before all other env variables
18-
# but it does not contain the library directly, so we need to look at the a sub-folder
19-
version = ""
20-
if "CONDA_PREFIX" in os.environ:
21-
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
22-
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
23-
version = float(f"{major}.{minor}")
24-
25-
if version == "" and "LD_LIBRARY_PATH" in os.environ:
26-
ld_path = os.environ["LD_LIBRARY_PATH"]
27-
paths = ld_path.split(":")
28-
version = ""
29-
for p in paths:
30-
if "cuda" in p:
31-
idx = p.rfind("cuda-")
32-
version = p[idx + 5 : idx + 5 + 4].replace("/", "")
33-
version = float(version)
34-
break
35-
36-
37-
assert version > 0
38-
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
39-
binary_name = binary_name.replace("libbitsandbytes_cuda", "")
40-
assert binary_name.startswith(str(version).replace(".", ""))

0 commit comments

Comments
 (0)