Skip to content

Commit a0a95fd

Browse files
authored
add device index (#1489)
1 parent 307fbd5 commit a0a95fd

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

bitsandbytes/nn/modules.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,9 @@ def cpu(self):
660660
self.SCB = SCB
661661
return self
662662

663-
def xpu(self):
663+
def xpu(self, device):
664664
# we store the 8-bit rows-major weight
665-
B = self.data.contiguous().to(torch.float16).xpu()
665+
B = self.data.contiguous().to(torch.float16).xpu(device)
666666
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
667667
if CBt is not None:
668668
del CBt
@@ -700,11 +700,11 @@ def to(self, *args, **kwargs):
700700
return self.cpu()
701701
elif device.type == "xpu":
702702
if self.data.dtype == torch.int8:
703-
self.data = self.data.contiguous().xpu()
703+
self.data = self.data.contiguous().xpu(device)
704704
self.CB = self.data
705705
return self
706706
else:
707-
return self.xpu()
707+
return self.xpu(device)
708708
else:
709709
new_param = Int8Params(
710710
super().to(device=device, dtype=dtype, non_blocking=non_blocking),

0 commit comments

Comments
 (0)