Skip to content

Commit aa3b245

Browse files
committed
fix 8bit int8 param device
Signed-off-by: jiqing-feng <[email protected]>
1 parent 9da03f1 commit aa3b245

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,7 +657,7 @@ def cpu(self):
657657

658658
def xpu(self):
659659
# we store the 8-bit rows-major weight
660-
B = self.data.contiguous().bfloat16().xpu()
660+
B = self.data.contiguous().float16().xpu()
661661
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
662662
if CBt is not None:
663663
del CBt
@@ -695,8 +695,9 @@ def to(self, *args, **kwargs):
695695
return self.cpu()
696696
elif device.type == "xpu":
697697
if self.data.dtype == torch.int8:
698+
self.data = self.data.contiguous().xpu()
698699
self.CB = self.data
699-
return super().xpu(device)
700+
return self
700701
else:
701702
return self.xpu()
702703
else:

0 commit comments

Comments
 (0)