Skip to content

Commit e8df8d6

Browse files
authored
Merge pull request #375 from rapsealk/fix/libcuda-to-torch
Replace libcudart.so with PyTorch's CUDA APIs
2 parents 6689afa + a24aae3 commit e8df8d6

File tree

1 file changed: +4 additions, -52 deletions

bitsandbytes/cuda_setup/main.py

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -279,37 +279,11 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
279279
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
280280

281281

282-
def check_cuda_result(cuda, result_val):
283-
# 3. Check for CUDA errors
284-
if result_val != 0:
285-
error_str = ct.c_char_p()
286-
cuda.cuGetErrorString(result_val, ct.byref(error_str))
287-
if error_str.value is not None:
288-
CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")
289-
else:
290-
CUDASetup.get_instance().add_log_entry(f"Unknown CUDA exception! Please check your CUDA install. It might also be that your GPU is too old.")
291-
292-
293282
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
294283
def get_cuda_version(cuda, cudart_path):
295284
if cuda is None: return None
296285

297-
try:
298-
cudart = ct.CDLL(cudart_path)
299-
except OSError:
300-
CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
301-
return None
302-
303-
version = ct.c_int()
304-
try:
305-
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ct.byref(version)))
306-
except AttributeError as e:
307-
CUDASetup.get_instance().add_log_entry(f'ERROR: {str(e)}')
308-
CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: libcudart.so path is {cudart_path}')
309-
CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.')
310-
version = int(version.value)
311-
major = version//1000
312-
minor = (version-(major*1000))//10
286+
major, minor = map(int, torch.version.cuda.split("."))
313287

314288
if major < 11:
315289
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
@@ -324,37 +298,15 @@ def get_cuda_lib_handle():
324298
except OSError:
325299
CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
326300
return None
327-
check_cuda_result(cuda, cuda.cuInit(0))
328301

329302
return cuda
330303

331304

332305
def get_compute_capabilities(cuda):
333-
"""
334-
1. find libcuda.so library (GPU driver) (/usr/lib)
335-
init_device -> init variables -> call function by reference
336-
2. call extern C function to determine CC
337-
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
338-
3. Check for CUDA errors
339-
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
340-
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
341-
"""
342-
343-
nGpus = ct.c_int()
344-
cc_major = ct.c_int()
345-
cc_minor = ct.c_int()
346-
347-
device = ct.c_int()
348-
349-
check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
350306
ccs = []
351-
for i in range(nGpus.value):
352-
check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
353-
ref_major = ct.byref(cc_major)
354-
ref_minor = ct.byref(cc_minor)
355-
# 2. call extern C function to determine CC
356-
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
357-
ccs.append(f"{cc_major.value}.{cc_minor.value}")
307+
for i in range(torch.cuda.device_count()):
308+
cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i))
309+
ccs.append(f"{cc_major}.{cc_minor}")
358310

359311
return ccs
360312

0 commit comments

Comments (0)