Skip to content

Commit 90c65fb

Browse files
Include device support tags for transformers multi-backend compatability; add xpu() and cpu() to Params4bit
1 parent b982796 commit 90c65fb

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

bitsandbytes/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
from .optim import adam
2121

2222
# This is a signal for integrations with transformers/diffusers.
23-
# Eventually, we will remove this and check based on release version.
23+
# Eventually we may remove this but it is currently required for compatibility.
2424
features = {"multi-backend"}
2525
supported_torch_devices = {
26-
"cuda",
2726
"cpu",
28-
# "mps",
29-
# "xpu",
30-
# "hpu",
31-
# "npu",
27+
"cuda", # NVIDIA/AMD GPU
28+
"xpu", # Intel GPU
29+
"hpu", # Gaudi
30+
"npu", # Ascend NPU
31+
"mps", # Apple Silicon
3232
}
3333

3434
if torch.cuda.is_available():

bitsandbytes/nn/modules.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,15 @@ def _quantize(self, device):
306306
self.bnb_quantized = True
307307
return self
308308

309+
def cpu(self):
310+
return self.to(device="cpu")
311+
309312
def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
310313
return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)
311314

315+
def xpu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
316+
return self.to(device="xpu" if device is None else device, non_blocking=non_blocking)
317+
312318
@overload
313319
def to(
314320
self: T,
@@ -326,7 +332,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
326332
def to(self, *args, **kwargs):
327333
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
328334

329-
if device is not None and device.type == "cuda" and not self.bnb_quantized:
335+
if device is not None and not self.bnb_quantized:
330336
return self._quantize(device)
331337
else:
332338
if self.quant_state is not None:

0 commit comments

Comments
 (0)