Skip to content

Commit 55f4c39

Browse files
committed
Polished CUDA SETUP replacement and added docs.
1 parent 1ab6758 commit 55f4c39

File tree

2 files changed

+72
-15
lines changed

2 files changed

+72
-15
lines changed

bitsandbytes/cuda_setup/main.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,16 @@ def initialize(self):
101101

102102
def manual_override(self):
103103
if torch.cuda.is_available():
104-
if 'CUDA_HOME' in os.environ and 'CUDA_VERSION' in os.environ:
105-
if len(os.environ['CUDA_HOME']) > 0 and len(os.environ['CUDA_VERSION']) > 0:
104+
if 'CUDA_VERSION' in os.environ:
105+
if len(os.environ['CUDA_VERSION']) > 0:
106+
warn((f'\n\n{"="*80}\n'
107+
'WARNING: Manual override via CUDA_VERSION env variable detected!\n'
108+
'CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n'
109+
'If this was unintended set the CUDA_VERSION variable to an empty string: export CUDA_VERSION=\n'
110+
'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n'
111+
'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64\n'
112+
f'Loading CUDA version: CUDA_VERSION={os.environ["CUDA_VERSION"]}'
113+
f'\n{"="*80}\n\n'))
106114
self.binary_name = self.binary_name[:-6] + f'{os.environ["CUDA_VERSION"]}.so'
107115

108116
def run_cuda_setup(self):
@@ -194,8 +202,8 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
194202

195203
non_existent_directories: Set[Path] = candidate_paths - existent_directories
196204
if non_existent_directories:
197-
CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
198-
f"be non-existent: {non_existent_directories}", is_warning=True)
205+
CUDASetup.get_instance().add_log_entry("The following directories listed in your path were found to "
206+
f"be non-existent: {non_existent_directories}", is_warning=False)
199207

200208
return existent_directories
201209

@@ -229,11 +237,12 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
229237
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
230238
"We select the PyTorch default libcudart.so, which is {torch.version.cuda},"
231239
"but this might missmatch with the CUDA version that is needed for bitsandbytes."
232-
"To override this behavior set the CUDA_HOME environmental variable"
233-
"For example, if you want to use the CUDA version wht the path"
234-
"/usr/local/cuda-11.2/lib/libcudart.so as the default,"
235-
"then add the following to your .bashrc:"
236-
"export CUDA_HOME=/usr/local/cuda-11.2")
240+
"To override this behavior set the CUDA_VERSION=<version string, e.g. 122> environmental variable"
241+
"For example, if you want to use the CUDA version 122"
242+
"CUDA_VERSION=122 python ..."
243+
"OR set the environmental variable in your .bashrc: export CUDA_VERSION=122"
244+
"In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g."
245+
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2")
237246
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
238247

239248

@@ -289,7 +298,8 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
289298

290299
warn_in_case_of_duplicates(cuda_runtime_libs)
291300

292-
print(cuda_runtime_libs, flush=True)
301+
cuda_setup = CUDASetup.get_instance()
302+
cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}')
293303

294304
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
295305

@@ -313,15 +323,15 @@ def get_compute_capabilities():
313323

314324

315325
def evaluate_cuda_setup():
326+
cuda_setup = CUDASetup.get_instance()
316327
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
317-
print('')
318-
print('='*35 + 'BUG REPORT' + '='*35)
319-
print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
328+
cuda_setup.add_log_entry('')
329+
cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35)
330+
cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
320331
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
321-
print('='*80)
332+
cuda_setup.add_log_entry('='*80)
322333
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
323334

324-
cuda_setup = CUDASetup.get_instance()
325335
cudart_path = determine_cuda_runtime_lib_path()
326336
ccs = get_compute_capabilities()
327337
ccs.sort()

how_to_use_nonpytorch_cuda.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
## How to use a CUDA version that is different from PyTorch
2+
3+
Some features of bitsandbytes may need a newer CUDA version than regularly supported by PyTorch binaries from conda / pip. In that case you can use the following instructions to load a precompiled bitsandbytes binary that works for you.
4+
5+
## Installing or determining the CUDA installation
6+
7+
Determine the path of the CUDA version that you want to use. Common paths are:
8+
```bash
9+
/usr/local/cuda
10+
/usr/local/cuda-XX.X
11+
```
12+
13+
where XX.X is the CUDA version number.
14+
15+
You can also install the CUDA version that you need locally with a script provided by bitsandbytes as follows:
16+
17+
```bash
18+
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
19+
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
20+
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122}
21+
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
22+
23+
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
24+
bash cuda_install.sh 117 ~/local 1
25+
```
26+
27+
## Setting the environmental variables CUDA_HOME, CUDA_VERSION, and LD_LIBRARY_PATH
28+
29+
To manually override the PyTorch installed CUDA version you need to set the following variables, like so:
30+
31+
```bash
32+
export CUDA_HOME=<PATH>
33+
export CUDA_VERSION=<VERSION>
34+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<PATH>
35+
```
36+
37+
For example, to use the local install path from above:
38+
39+
```bash
40+
export CUDA_HOME=/home/tim/local/cuda-11.7
41+
export CUDA_VERSION=117
42+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7
43+
```
44+
45+
It is best to add these lines to the `.bashrc` file to make them permanent.
46+
47+
If you now launch bitsandbytes with these environmental variables the PyTorch CUDA version will be overridden by the new CUDA version and a different bitsandbytes library is loaded (in this case version 117).

0 commit comments

Comments
 (0)