@@ -814,11 +814,31 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
814814                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not 
815815                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more 
816816                information. 
817-             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): 
817+             device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*): 
818818                A map that specifies where each submodule should go. It doesn't need to be defined for each 
819819                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the 
820820                same device. Defaults to `None`, meaning that the model will be loaded on CPU. 
821821
822+                 Examples: 
823+ 
824+                 ```py 
825+                 >>> from diffusers import AutoModel 
826+                 >>> import torch 
827+ 
828+                 >>> # This works. 
829+                 >>> model = AutoModel.from_pretrained( 
830+                 ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda" 
831+                 ... ) 
832+                 >>> # This also works (integer accelerator device ID). 
833+                 >>> model = AutoModel.from_pretrained( 
834+                 ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0 
835+                 ... ) 
836+                 >>> # Specifying a supported offloading strategy like "auto" also works. 
837+                 >>> model = AutoModel.from_pretrained( 
838+                 ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto" 
839+                 ... ) 
840+                 ``` 
841+ 
822842                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For 
823843                more information about each option see [designing a device 
824844                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). You can 
@@ -1390,7 +1410,7 @@ def _load_pretrained_model(
13901410        low_cpu_mem_usage : bool  =  True ,
13911411        dtype : Optional [Union [str , torch .dtype ]] =  None ,
13921412        keep_in_fp32_modules : Optional [List [str ]] =  None ,
1393-         device_map : Dict [str , Union [int , str , torch .device ]] =  None ,
1413+         device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
13941414        offload_state_dict : Optional [bool ] =  None ,
13951415        offload_folder : Optional [Union [str , os .PathLike ]] =  None ,
13961416        dduf_entries : Optional [Dict [str , DDUFEntry ]] =  None ,
0 commit comments