Skip to content

Commit 1e27a22

Browse files
committed
fix 8bit int8 param device
Signed-off-by: jiqing-feng <[email protected]>
1 parent aa3b245 commit 1e27a22

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def __deepcopy__(self, memo):
644644

645645
def cpu(self):
646646
# we store the 8-bit rows-major weight
647-
B = self.data.contiguous().bfloat16().cpu()
647+
B = self.data.contiguous().to(torch.bfloat16).cpu()
648648
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
649649
if CBt is not None:
650650
del CBt
@@ -657,7 +657,7 @@ def cpu(self):
657657

658658
def xpu(self):
659659
# we store the 8-bit rows-major weight
660-
B = self.data.contiguous().float16().xpu()
660+
B = self.data.contiguous().to(torch.float16).xpu()
661661
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
662662
if CBt is not None:
663663
del CBt

0 commit comments

Comments
 (0)