Skip to content

Commit 431819d

Browse files
Stub out for multi-platform
1 parent c703d8d commit 431819d

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

bitsandbytes/__init__.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,34 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7+
import torch
8+
79
from . import _ops, research, utils
810
from .autograd._functions import (
911
MatmulLtState,
1012
matmul,
1113
matmul_4bit,
1214
)
1315
from .backends.cpu import ops as cpu_ops
14-
from .backends.cuda import ops as cuda_ops ## TODO: We would guard this for CUDA only
1516
from .backends.default import ops as default_ops
1617
from .nn import modules
1718
from .optim import adam
1819

20+
# This is a signal for integrations with transformers/diffusers.
21+
# Eventually, we will remove this and check based on release version.
22+
features = {"multi-backend"}
23+
supported_torch_devices = {
24+
"cuda",
25+
"cpu",
26+
# "mps",
27+
# "xpu",
28+
# "hpu",
29+
# "npu",
30+
}
31+
32+
if torch.cuda.is_available():
33+
from .backends.cuda import ops as cuda_ops
34+
1935
__pdoc__ = {
2036
"libbitsandbytes": False,
2137
"optim.optimizer.Optimizer8bit": False,

bitsandbytes/cextension.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
1919
2020
The library is not guaranteed to exist at the returned path.
2121
"""
22-
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
22+
23+
prefix = "rocm" if torch.version.hip else "cuda"
24+
library_name = f"libbitsandbytes_{prefix}{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
2325

2426
override_value = os.environ.get("BNB_CUDA_VERSION")
2527
if override_value:
@@ -76,7 +78,7 @@ def get_native_library() -> BNBNativeLibrary:
7678

7779
logger.warning(
7880
"The installed version of bitsandbytes was compiled without GPU support. "
79-
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.",
81+
"8-bit optimizers and GPU quantization are unavailable.",
8082
)
8183
return BNBNativeLibrary(dll)
8284

bitsandbytes/cuda_specs.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses
2+
from functools import lru_cache
23
from typing import List, Optional, Tuple
34

45
import torch
@@ -12,22 +13,26 @@ class CUDASpecs:
1213

1314
@property
1415
def has_imma(self) -> bool:
15-
return self.highest_compute_capability >= (7, 5)
16+
return torch.version.hip or self.highest_compute_capability >= (7, 5)
1617

1718

1819
def get_compute_capabilities() -> List[Tuple[int, int]]:
1920
return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))
2021

2122

23+
@lru_cache(None)
2224
def get_cuda_version_tuple() -> Tuple[int, int]:
23-
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
24-
major, minor = map(int, torch.version.cuda.split("."))
25-
return major, minor
25+
if torch.version.cuda:
26+
return map(int, torch.version.cuda.split(".")[0:2])
27+
elif torch.version.hip:
28+
return map(int, torch.version.hip.split(".")[0:2])
29+
30+
return None
2631

2732

2833
def get_cuda_version_string() -> str:
2934
major, minor = get_cuda_version_tuple()
30-
return f"{major}{minor}"
35+
return f"{major * 10 + minor}"
3136

3237

3338
def get_cuda_specs() -> Optional[CUDASpecs]:

0 commit comments

Comments
 (0)