Skip to content

Commit 08d429b

Browse files
committed
more
1 parent e9ccc73 commit 08d429b

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
986986
)
987987
else:
988988
device_map = {"": device_map}
989+
# {"": device} case.
990+
elif isinstance(device_map, dict) and len(dict) == 1:
991+
device_value = list(device_map.values())[0]
992+
if isinstance(device_value, str):
993+
try:
994+
device_map = {"": torch.device(device_value)}
995+
except RuntimeError:
996+
raise ValueError(f"Invalid value ({device_value}) specified in the {device_map=}.")
989997

990998
if device_map is not None:
991999
if low_cpu_mem_usage is None:

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,6 +1108,18 @@ def test_passing_non_dict_device_map_works(self, device_map):
11081108
output = loaded_model(**inputs_dict)
11091109
assert output.sample.shape == (4, 4, 16, 16)
11101110

1111+
@parameterized.expand([{"": "cuda"}, {"": torch.device("cuda")}, {"": "cpu"}, {"": torch.device("cpu")}])
1112+
@require_torch_gpu
1113+
def test_passing_dict_device_map_works(self, device_map):
1114+
# There are other valid dict-based `device_map` values too. It's best to refer to
1115+
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
1116+
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1117+
loaded_model = self.model_class.from_pretrained(
1118+
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
1119+
)
1120+
output = loaded_model(**inputs_dict)
1121+
assert output.sample.shape == (4, 4, 16, 16)
1122+
11111123
@require_peft_backend
11121124
def test_load_attn_procs_raise_warning(self):
11131125
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

0 commit comments

Comments
 (0)