Skip to content

Commit 478c8d0

Browse files
committed
fix device_map in test
1 parent c0fc91b commit 478c8d0

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/models/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1764,11 +1764,12 @@ def test_passing_non_dict_device_map_works(self, device_map):
17641764

17651765
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
17661766
@require_torch_gpu
1767-
def test_passing_dict_device_map_works(self, name, device_map):
1767+
def test_passing_dict_device_map_works(self, name, device):
17681768
# There are other valid dict-based `device_map` values too. It's best to refer to
17691769
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
17701770
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17711771
model = self.model_class(**init_dict).eval()
1772+
device_map = {name: device}
17721773
with tempfile.TemporaryDirectory() as tmpdir:
17731774
model.save_pretrained(tmpdir)
17741775
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)

0 commit comments

Comments
 (0)