Skip to content

Commit 3bebf25

Browse files
committed
update
1 parent 08d429b commit 3bebf25

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -987,7 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
987987
else:
988988
device_map = {"": device_map}
989989
# {"": device} case.
990-
elif isinstance(device_map, dict) and len(dict) == 1:
990+
elif isinstance(device_map, dict) and len(device_map) == 1:
991991
device_value = list(device_map.values())[0]
992992
if isinstance(device_value, str):
993993
try:

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,14 +1108,14 @@ def test_passing_non_dict_device_map_works(self, device_map):
11081108
output = loaded_model(**inputs_dict)
11091109
assert output.sample.shape == (4, 4, 16, 16)
11101110

1111-
@parameterized.expand([{"": "cuda"}, {"": torch.device("cuda")}, {"": "cpu"}, {"": torch.device("cpu")}])
1111+
@parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
11121112
@require_torch_gpu
1113-
def test_passing_dict_device_map_works(self, device_map):
1113+
def test_passing_dict_device_map_works(self, name, device_map):
11141114
# There are other valid dict-based `device_map` values too. It's best to refer to
11151115
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
11161116
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
11171117
loaded_model = self.model_class.from_pretrained(
1118-
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
1118+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
11191119
)
11201120
output = loaded_model(**inputs_dict)
11211121
assert output.sample.shape == (4, 4, 16, 16)

0 commit comments

Comments
 (0)