Commit 5e35ac5

Commit message: docs
Parent: a4dd7fd

File tree: 2 files changed (+38 −2 lines)


src/diffusers/models/modeling_utils.py

22 additions & 2 deletions
````diff
@@ -814,11 +814,31 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+            device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device. Defaults to `None`, meaning that the model will be loaded on CPU.

+                Examples:
+
+                ```py
+                >>> from diffusers import AutoModel
+                >>> import torch
+
+                >>> # This works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+                ... )
+                >>> # This also works (integer accelerator device ID).
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+                ... )
+                >>> # Specifying a supported offloading strategy like "auto" also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+                ... )
+                ```
+
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). You can
````
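For contrast with the single-device forms documented above, the dict form of the same parameter maps module names to devices. A minimal sketch (the module names here are illustrative, not taken from a real UNet):

```py
import torch

# Hypothetical module-name keys for illustration; real keys depend on the model.
device_map = {
    "conv_in": 0,                         # int accelerator ID
    "down_blocks": "cuda:0",              # device string
    "mid_block": torch.device("cuda:0"),  # torch.device object
    "up_blocks": "cpu",                   # offload a submodule to CPU
}
```

The second hunk in the same file widens the `_load_pretrained_model` annotation to match the documented union: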
```diff
@@ -1390,7 +1410,7 @@ def _load_pretrained_model(
         low_cpu_mem_usage: bool = True,
         dtype: Optional[Union[str, torch.dtype]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
-        device_map: Dict[str, Union[int, str, torch.device]] = None,
+        device_map: Union[int, str, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
```
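To make the widened annotation concrete, here is a minimal sketch of how a single-device `device_map` could be normalized into the dict form downstream code consumes. The helper name `_normalize_device_map`, the empty-string key convention, and its placement are assumptions for illustration; only the error message is taken from this commit's new test:

```py
import torch

# Offloading strategies supported by 🤗 Accelerate.
_SUPPORTED_STRATEGIES = ("auto", "balanced", "balanced_low_0", "sequential")


def _normalize_device_map(device_map):
    """Hypothetical helper: collapse a single-device `device_map` into the
    dict form, where the empty-string key "" addresses the whole model."""
    if device_map is None or isinstance(device_map, dict):
        return device_map  # already absent or already in dict form
    if isinstance(device_map, int):
        if device_map < 0:
            # Error message taken from the new test in this commit.
            raise ValueError("You can't pass device_map as a negative int.")
        return {"": device_map}
    if isinstance(device_map, torch.device):
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in _SUPPORTED_STRATEGIES:
        return {"": torch.device(device_map)}  # e.g. "cuda", "cuda:0", "cpu"
    return device_map  # strategy strings pass through to 🤗 Accelerate
```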

tests/models/unets/test_models_unet_2d_condition.py

16 additions & 0 deletions
```diff
@@ -1086,6 +1086,22 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

+    def test_wrong_device_map_raises_error(self):
+        with self.assertRaises(ValueError) as err_ctx:
+            _ = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=-1)
+        msg_substring = "You can't pass device_map as a negative int"
+        assert msg_substring in str(err_ctx.exception)
+
+    @require_torch_gpu
+    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
+    def test_passing_non_dict_device_map_works(self, device_map):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=device_map
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```
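As a usage note, the behavior these tests pin down can be reproduced directly. A sketch assuming a CUDA device and access to the same `hf-internal-testing/unet2d-sharded-dummy-subfolder` dummy checkpoint the tests use:

```py
import torch
from diffusers import UNet2DConditionModel

# Any single-device value (int, str, or torch.device) loads the model
# directly onto that device.
model = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=torch.device("cuda:0")
)
print(next(model.parameters()).device)  # expected: cuda:0

# A negative int is rejected up front with a ValueError.
try:
    UNet2DConditionModel.from_pretrained(
        "hf-internal-testing/unet2d-sharded-dummy-subfolder", device_map=-1
    )
except ValueError as err:
    print(err)  # contains "You can't pass device_map as a negative int"
```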
