Skip to content

Commit 7a6563b

Browse files
committed
Default to CPU library on CUDA error+small refactor.
1 parent d9112dc commit 7a6563b

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

bitsandbytes/cuda_setup/main.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def generate_instructions(self):
6363
elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0:
6464
make_cmd += ' make cuda11x'
6565

66-
has_cublaslt = self.cc in ["7.5", "8.0", "8.6"]
66+
has_cublaslt = is_cublasLt_compatible(self.cc)
6767
if not has_cublaslt:
6868
make_cmd += '_nomatmul'
6969

@@ -94,7 +94,7 @@ def run_cuda_setup(self):
9494
try:
9595
if not binary_path.exists():
9696
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
97-
legacy_binary_name = "libbitsandbytes.so"
97+
legacy_binary_name = "libbitsandbytes_cpu.so"
9898
self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
9999
binary_path = package_dir / legacy_binary_name
100100
if not binary_path.exists():
@@ -137,6 +137,15 @@ def get_instance(cls):
137137
return cls._instance
138138

139139

140+
def is_cublasLt_compatible(cc):
141+
has_cublaslt = False
142+
if cc is not None:
143+
cc_major, cc_minor = cc.split('.')
144+
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
145+
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True)
146+
else:
147+
has_cublaslt = True
148+
return has_cublaslt
140149

141150
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
142151
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
@@ -368,13 +377,7 @@ def evaluate_cuda_setup():
368377
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
369378

370379
# 7.5 is the minimum CC vor cublaslt
371-
has_cublaslt = False
372-
if cc is not None:
373-
cc_major, cc_minor = cc.split('.')
374-
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
375-
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True)
376-
else:
377-
has_cublaslt = True
380+
has_cublaslt = is_cublasLt_compatible(cc)
378381

379382
# TODO:
380383
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)

0 commit comments

Comments
 (0)