Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)


if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
Expand All @@ -690,8 +691,8 @@ def to(self, *args, **kwargs):
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
)
new_param.CB = self.CB
new_param.SCB = self.SCB
new_param.CB = self.CB.to(device=device) if self.CB != None else self.CB
new_param.SCB = self.SCB.to(device=device) if self.SCB != None else self.SCB

return new_param

Expand Down