Skip to content

Commit d3658c5

Browse files
authored
Fix xpu to cpu (#1570)
* fix xpu to cpu Signed-off-by: jiqing-feng <[email protected]> * fix xpu cpu data device Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]>
1 parent 8fe6325 commit d3658c5

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

bitsandbytes/nn/modules.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -694,32 +694,30 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
694694
def to(self, *args, **kwargs):
695695
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
696696

697-
if device is not None and device.type in ("cuda", "xpu", "cpu"):
697+
if device is not None:
698698
if device.type == "cuda" and self.data.device.type == "cpu":
699699
return self.cuda(device)
700700
elif device.type == "cpu":
701701
if self.data.dtype == torch.int8:
702702
self.CB = self.data
703-
return self
704703
else:
705704
return self.cpu()
706705
elif device.type == "xpu":
707706
if self.data.dtype == torch.int8:
708-
self.data = self.data.contiguous().xpu(device)
707+
self.data = self.data.contiguous()
709708
self.CB = self.data
710-
return self
711-
else:
709+
if self.data.device.type == "cpu":
712710
return self.xpu(device)
713-
else:
714-
new_param = Int8Params(
715-
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
716-
requires_grad=self.requires_grad,
717-
has_fp16_weights=self.has_fp16_weights,
718-
)
719-
new_param.CB = self.CB
720-
new_param.SCB = self.SCB
721711

722-
return new_param
712+
new_param = Int8Params(
713+
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
714+
requires_grad=self.requires_grad,
715+
has_fp16_weights=self.has_fp16_weights,
716+
)
717+
new_param.CB = self.CB
718+
new_param.SCB = self.SCB
719+
720+
return new_param
723721

724722

725723
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):

0 commit comments

Comments
 (0)