
Commit b8d1c26

Linear8bitLt: support device movement after forward() (#1769)
1 parent 42e8abc commit b8d1c26

2 files changed: +68, -7 lines

bitsandbytes/nn/modules.py

Lines changed: 30 additions & 7 deletions

@@ -679,19 +679,27 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
     def to(self, *args, **kwargs):
         device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
 
-        if device is not None and device.type != "meta" and self.data.device.type == "cpu":
-            if device.type != "cpu" or self.data.dtype != torch.int8:
-                return self._quantize(device)
-            elif self.data.dtype == torch.int8 and device.type == "cpu":
-                self.CB = self.data
+        is_quantized = self.data.dtype == torch.int8
 
+        if not is_quantized and device is not None and device.type != "meta" and self.data.device.type == "cpu":
+            # We're moving from a CPU device to a non-meta device.
+            # In this circumstance, we want to quantize if we haven't already.
+            return self._quantize(device)
+
+        # Create a new parameter on the target device.
         new_param = Int8Params(
             super().to(device=device, dtype=dtype, non_blocking=non_blocking),
             requires_grad=self.requires_grad,
             has_fp16_weights=self.has_fp16_weights,
         )
-        new_param.CB = self.CB
-        new_param.SCB = self.SCB
+
+        # If we had already quantized, move the statistics appropriately.
+        if is_quantized and device is not None:
+            if self.CB is not None:
+                new_param.CB = new_param.data
+
+            if self.SCB is not None:
+                new_param.SCB = self.SCB.to(device)
 
         return new_param

@@ -1037,6 +1045,21 @@ def init_8bit_state(self):
         self.weight.CB = None
         self.weight.SCB = None
 
+    def to(self, *args, **kwargs):
+        # Call the parent to() method to handle standard parameter/buffer movement
+        result = super().to(*args, **kwargs)
+
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        # Handle state tensors if needed.
+        if device is not None:
+            if result.state.CB is not None:
+                result.state.CB = result.state.CB.to(device)
+            if result.state.SCB is not None:
+                result.state.SCB = result.state.SCB.to(device)
+
+        return result
+
     def forward(self, x: torch.Tensor):
         self.state.is_training = self.training
         if self.weight.CB is not None:
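
The reason Linear8bitLt needs its own to() override is that nn.Module.to() only relocates registered parameters and buffers; tensors stored as plain Python attributes, such as the quantization state the layer keeps on its state object after forward(), stay on their old device. Below is a minimal PyTorch sketch of that behavior and of the override pattern used above. The StatefulLayer module is made up for illustration and is not part of bitsandbytes.

import torch
import torch.nn as nn

class StatefulLayer(nn.Module):
    """Hypothetical module: one registered parameter and one plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))  # moved by Module.to()
        self.stats = torch.zeros(4)                    # plain attribute: NOT moved by Module.to()

    def to(self, *args, **kwargs):
        # Mirror the approach in this commit: let the parent handle parameters/buffers,
        # then move the plain tensor attribute by hand.
        result = super().to(*args, **kwargs)
        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        if device is not None and result.stats is not None:
            result.stats = result.stats.to(device)
        return result

if torch.cuda.is_available():
    layer = StatefulLayer().to("cuda")
    assert layer.weight.device.type == "cuda"
    assert layer.stats.device.type == "cuda"  # would still be "cpu" without the override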

tests/test_linear8bitlt.py

Lines changed: 38 additions & 0 deletions

@@ -293,3 +293,41 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     grad_compiled = x.grad.clone()
 
     torch.testing.assert_close(grad_compiled, grad_ref)
+
+
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+def test_linear8bitlt_device_movement(device):
+    """Test moving a Linear8bitLt layer between CPU and an accelerator device."""
+
+    # Create a Linear8bitLt layer on CPU
+    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
+    torch.nn.init.xavier_uniform_(layer.weight)
+
+    # Create a sample input.
+    x = torch.randn(4, 32, dtype=torch.float16, device="cpu")
+
+    # Move to the device. This should quantize the weights.
+    layer = layer.to(device)
+    assert layer.weight.data.dtype == torch.int8
+
+    # Call the layer on the accelerator device.
+    out_accelerator = layer(x.to(device))
+
+    # Move back to CPU and call again.
+    layer = layer.to("cpu")
+    out_cpu = layer(x)
+
+    # Move back to the accelerator device and call again.
+    layer = layer.to(device)
+    out_accelerator_2 = layer(x.to(device))
+
+    # Move back to the CPU and call one last time.
+    layer = layer.to("cpu")
+    out_cpu_2 = layer(x)
+
+    # CPU outputs should match both times.
+    torch.testing.assert_close(out_cpu_2, out_cpu, rtol=1e-8, atol=1e-8)
+
+    # Accelerator outputs should match both times.
+    torch.testing.assert_close(out_accelerator_2, out_accelerator, rtol=1e-8, atol=1e-8)
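
For context, the round trip exercised by the test can also be written against the public API directly. This is a minimal usage sketch, not part of the commit, assuming bitsandbytes is installed and a CUDA device is available.

import torch
import bitsandbytes as bnb

if torch.cuda.is_available():
    layer = bnb.nn.Linear8bitLt(32, 128, bias=False, has_fp16_weights=False)
    x = torch.randn(4, 32, dtype=torch.float16)

    layer = layer.to("cuda")       # quantizes the fp16 weights to int8
    y_gpu = layer(x.to("cuda"))    # forward() sets up the layer's 8-bit state

    layer = layer.to("cpu")        # with this commit, the state tensors follow the module
    y_cpu = layer(x)

    layer = layer.to("cuda")       # and the layer can move back again after forward()
    y_gpu_2 = layer(x.to("cuda"))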

0 commit comments
