Skip to content

Commit 27ed776

Browse files
Authored by rezaqorbani (Reza Qorbani) and co-authors
Fix hf_device_map device comparison in prepare_model (#3895)
Co-authored-by: Reza Qorbani <[email protected]>
1 parent 4bf9964 commit 27ed776

File tree

2 files changed: +27 −1 lines changed

src/accelerate/accelerator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1811,10 +1811,11 @@ def prepare_model(
18111811
else:
18121812
current_device_index = current_device
18131813

1814+
current_device_index = int(current_device_index) if current_device_index is not None else None
18141815
if self.device.type == "cpu" and is_bitsandbytes_multi_backend_available():
18151816
# bnb with multi-backend supports CPU which don't need to check index.
18161817
pass
1817-
elif torch.device(current_device_index) != self.device:
1818+
elif torch.device(self.device.type, current_device_index) != self.device:
18181819
# if on the first device (GPU 0) we don't care
18191820
if (self.device.index is not None) or (current_device_index != 0):
18201821
raise ValueError(

tests/test_accelerator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
from accelerate.test_utils.testing import (
4646
AccelerateTestCase,
47+
assert_exception,
4748
require_cuda,
4849
require_non_torch_xla,
4950
require_torchdata_stateful_dataloader,
@@ -861,3 +862,27 @@ def forward(self, x):
861862
# weight is on the meta device, we need a `value` to put in on 0
862863
x = torch.randn(1, 2)
863864
my_model(x)
865+
866+
@require_non_torch_xla
867+
def test_prepare_model_8bit_cpu_offload_raises_valueerror_not_typeerror(self):
868+
class ModelForTest(torch.nn.Module):
869+
def __init__(self):
870+
super().__init__()
871+
self.l = torch.nn.Linear(2, 2)
872+
873+
def forward(self, x):
874+
return self.l(x)
875+
876+
accelerator = Accelerator()
877+
model = ModelForTest()
878+
879+
# Trigger the 8-bit/4-bit + hf_device_map code path.
880+
model.is_loaded_in_8bit = True
881+
model.hf_device_map = {"": "cpu"}
882+
883+
with (
884+
patch("accelerate.accelerator.is_bitsandbytes_multi_backend_available", return_value=False),
885+
patch("accelerate.accelerator.is_xpu_available", return_value=False),
886+
):
887+
with assert_exception(ValueError, "CPU or disk offload"):
888+
accelerator.prepare_model(model)

0 commit comments

Comments (0)