@@ -814,14 +814,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
     Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
     guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
     information.
-device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
     A map that specifies where each submodule should go. It doesn't need to be defined for each
     parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
     same device. Defaults to `None`, meaning that the model will be loaded on CPU.
 
+    Examples:
+
+    ```py
+    >>> from diffusers import AutoModel
+    >>> import torch

+    >>> # A device string works.
+    >>> model = AutoModel.from_pretrained(
+    ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+    ... )
+    >>> # An integer accelerator device ID also works.
+    >>> model = AutoModel.from_pretrained(
+    ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+    ... )
+    >>> # So does a supported offloading strategy such as "auto".
+    >>> model = AutoModel.from_pretrained(
+    ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+    ... )
+    >>> # And so does a per-module dictionary.
+    >>> model = AutoModel.from_pretrained(
+    ...     "stabilityai/stable-diffusion-xl-base-1.0",
+    ...     subfolder="unet",
+    ...     device_map={"": torch.device("cuda")},
+    ... )
+    ```
+
     Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
     more information about each option see [designing a device
-    map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+    map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+    can also refer to the [Diffusers-specific
+    documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+    for more concrete examples.
 max_memory (`Dict`, *optional*):
     A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory
     available for each GPU and the available CPU RAM if unset.
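The `max_memory` budget pairs with `device_map="auto"`: Accelerate fills each listed device up to its cap and places the remainder elsewhere. A minimal sketch in the same doctest style (the 10GiB/30GiB budgets are illustrative assumptions, not recommendations):

```py
>>> from diffusers import AutoModel

>>> # Cap accelerator 0 at 10 GiB and CPU RAM at 30 GiB; weights that don't fit
>>> # on the GPU under this budget are assigned to the CPU by the "auto" map.
>>> model = AutoModel.from_pretrained(
...     "stabilityai/stable-diffusion-xl-base-1.0",
...     subfolder="unet",
...     device_map="auto",
...     max_memory={0: "10GiB", "cpu": "30GiB"},
... )
```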
@@ -1387,7 +1416,7 @@ def _load_pretrained_model(
     low_cpu_mem_usage: bool = True,
     dtype: Optional[Union[str, torch.dtype]] = None,
     keep_in_fp32_modules: Optional[List[str]] = None,
-    device_map: Dict[str, Union[int, str, torch.device]] = None,
+    device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
     offload_state_dict: Optional[bool] = None,
     offload_folder: Optional[Union[str, os.PathLike]] = None,
     dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
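The widened annotation means a single device given as a `str`, `int`, or `torch.device` is now accepted alongside the per-module dict. A minimal sketch of the normalization this implies; the helper name and the strategy list are assumptions for illustration, not the actual diffusers implementation:

```py
import torch
from typing import Dict, Union

# Offloading strategies are resolved by Accelerate rather than mapped directly
# to a device (assumed list, mirroring Accelerate's documented strategies).
_STRATEGIES = ("auto", "balanced", "balanced_low_0", "sequential")


def normalize_device_map(
    device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]],
) -> Union[str, Dict[str, Union[int, str, torch.device]]]:
    """Hypothetical helper: wrap a single device into the {"": device} dict form."""
    if isinstance(device_map, (int, torch.device)):
        # The empty-string key sends the whole model (every submodule) to one device.
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in _STRATEGIES:
        # A plain device string such as "cuda" or "cuda:0".
        return {"": device_map}
    # Dicts and strategy strings pass through untouched.
    return device_map
```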