@@ -816,14 +816,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
 information.
-device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
 A map that specifies where each submodule should go. It doesn't need to be defined for each
 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
 same device. Defaults to `None`, meaning that the model will be loaded on CPU.
 
+Examples:
+
+```py
+>>> from diffusers import AutoModel
+>>> import torch
+
+>>> # A device string works.
+>>> model = AutoModel.from_pretrained(
+...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+... )
+>>> # An integer accelerator device ID also works.
+>>> model = AutoModel.from_pretrained(
+...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+... )
+>>> # A supported offloading strategy like "auto" also works.
+>>> model = AutoModel.from_pretrained(
+...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+... )
+>>> # A dictionary `device_map` also works.
+>>> model = AutoModel.from_pretrained(
+...     "stabilityai/stable-diffusion-xl-base-1.0",
+...     subfolder="unet",
+...     device_map={"": torch.device("cuda")},
+... )
+```
+
 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
 more information about each option see [designing a device
-map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+can also refer to the [Diffusers-specific
+documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+for more concrete examples.
 max_memory (`Dict`, *optional*):
 A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
 available for each GPU and the available CPU RAM if unset.
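
To make `max_memory` concrete alongside the new `device_map` forms, here is a minimal sketch combining the `"auto"` strategy with an explicit memory budget. The memory figures are illustrative assumptions, not values from this change; keys follow Accelerate's convention of integer accelerator indices plus the literal `"cpu"`.

```py
>>> from diffusers import AutoModel

>>> # "auto" lets Accelerate place submodules, but only within the stated budget:
>>> # at most 10GiB on accelerator 0, with up to 30GiB of CPU RAM for overflow.
>>> model = AutoModel.from_pretrained(
...     "stabilityai/stable-diffusion-xl-base-1.0",
...     subfolder="unet",
...     device_map="auto",
...     max_memory={0: "10GiB", "cpu": "30GiB"},
... )
```
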
@@ -1389,7 +1418,7 @@ def _load_pretrained_model(
 low_cpu_mem_usage: bool = True,
 dtype: Optional[Union[str, torch.dtype]] = None,
 keep_in_fp32_modules: Optional[List[str]] = None,
-device_map: Dict[str, Union[int, str, torch.device]] = None,
+device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
 offload_state_dict: Optional[bool] = None,
 offload_folder: Optional[Union[str, os.PathLike]] = None,
 dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
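
One form the widened annotation accepts that the docstring examples above don't show is a bare `torch.device`. A minimal sketch, assuming a CUDA device at index 0; the placement is equivalent to the dict form `{"": torch.device("cuda", 0)}`:

```py
>>> from diffusers import AutoModel
>>> import torch

>>> # A scalar torch.device is a valid `device_map`; the whole model is placed on it.
>>> model = AutoModel.from_pretrained(
...     "stabilityai/stable-diffusion-xl-base-1.0",
...     subfolder="unet",
...     device_map=torch.device("cuda", 0),
... )
```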