Skip to content

Commit 2b4cc25

Browse files
committed
fix: Get device's compute capability
1 parent f511026 commit 2b4cc25

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

bitsandbytes/cuda_setup/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
279279
def get_cuda_version(cuda, cudart_path):
280280
if cuda is None: return None
281281

282-
version = torch._C._cuda_getCompiledVersion()
283-
major = version//1000
284-
minor = (version-(major*1000))//10
282+
major, minor = map(int, torch.version.cuda.split("."))
285283

286284
if major < 11:
287285
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
@@ -302,7 +300,9 @@ def get_cuda_lib_handle():
302300

303301
def get_compute_capabilities(cuda):
304302
ccs = []
305-
ccs.append(torch.version.cuda)
303+
for i in range(torch.cuda.device_count()):
304+
cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i))
305+
ccs.append(f"{cc_major}.{cc_minor}")
306306

307307
return ccs
308308

0 commit comments

Comments
 (0)