Skip to content

Commit 97b2567

Browse files
committed
fix: Replace libcudart with pytorch api
1 parent 659a7df commit 97b2567

File tree

1 file changed

+3
-24
lines changed

1 file changed

+3
-24
lines changed

bitsandbytes/cuda_setup/main.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -326,31 +326,10 @@ def get_cuda_lib_handle():
326326

327327

328328
def get_compute_capabilities(cuda):
329-
"""
330-
1. find libcuda.so library (GPU driver) (/usr/lib)
331-
init_device -> init variables -> call function by reference
332-
2. call extern C function to determine CC
333-
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
334-
3. Check for CUDA errors
335-
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
336-
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
337-
"""
338-
339-
nGpus = ct.c_int()
340-
cc_major = ct.c_int()
341-
cc_minor = ct.c_int()
342-
343-
device = ct.c_int()
344-
345-
check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
346329
ccs = []
347-
for i in range(nGpus.value):
348-
check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
349-
ref_major = ct.byref(cc_major)
350-
ref_minor = ct.byref(cc_minor)
351-
# 2. call extern C function to determine CC
352-
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
353-
ccs.append(f"{cc_major.value}.{cc_minor.value}")
330+
# for i in range(torch.cuda.device_count()):
331+
# device = torch.cuda.device(i)
332+
ccs.append(torch.version.cuda)
354333

355334
return ccs
356335

0 commit comments

Comments
 (0)