Skip to content

Commit f8d4a1e

Browse files
johannaSommerDN6
andauthored
fix: remove torch_dtype="auto" option from docstrings (#11513)
Co-authored-by: Dhruv Nair <[email protected]>
1 parent ddd0cfb commit f8d4a1e

File tree

9 files changed

+24
-34
lines changed

9 files changed

+24
-34
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
187187
original_config (`str`, *optional*):
188188
Dict or path to a yaml file containing the configuration for the model in its original format.
189189
If a dict is provided, it will be used to initialize the model configuration.
190-
torch_dtype (`str` or `torch.dtype`, *optional*):
191-
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
192-
dtype is automatically derived from the model's weights.
190+
torch_dtype (`torch.dtype`, *optional*):
191+
Override the default `torch.dtype` and load the model with another dtype.
193192
force_download (`bool`, *optional*, defaults to `False`):
194193
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
195194
cached versions if they exist.

src/diffusers/models/adapter.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
161161
pretrained_model_path (`os.PathLike`):
162162
A path to a *directory* containing model weights saved using
163163
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
164-
torch_dtype (`str` or `torch.dtype`, *optional*):
165-
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
166-
will be automatically derived from the model's weights.
164+
torch_dtype (`torch.dtype`, *optional*):
165+
Override the default `torch.dtype` and load the model under this dtype.
167166
output_loading_info(`bool`, *optional*, defaults to `False`):
168167
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
169168
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):

src/diffusers/models/auto_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLi
5252
cache_dir (`Union[str, os.PathLike]`, *optional*):
5353
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
5454
is not used.
55-
torch_dtype (`str` or `torch.dtype`, *optional*):
56-
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
57-
dtype is automatically derived from the model's weights.
55+
torch_dtype (`torch.dtype`, *optional*):
56+
Override the default `torch.dtype` and load the model with another dtype.
5857
force_download (`bool`, *optional*, defaults to `False`):
5958
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
6059
cached versions if they exist.

src/diffusers/models/controlnets/multicontrolnet.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
130130
A path to a *directory* containing model weights saved using
131131
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
132132
`./my_model_directory/controlnet`.
133-
torch_dtype (`str` or `torch.dtype`, *optional*):
134-
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
135-
will be automatically derived from the model's weights.
133+
torch_dtype (`torch.dtype`, *optional*):
134+
Override the default `torch.dtype` and load the model under this dtype.
136135
output_loading_info(`bool`, *optional*, defaults to `False`):
137136
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
138137
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):

src/diffusers/models/controlnets/multicontrolnet_union.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
143143
A path to a *directory* containing model weights saved using
144144
[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
145145
`./my_model_directory/controlnet`.
146-
torch_dtype (`str` or `torch.dtype`, *optional*):
147-
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
148-
will be automatically derived from the model's weights.
146+
torch_dtype (`torch.dtype`, *optional*):
147+
Override the default `torch.dtype` and load the model under this dtype.
149148
output_loading_info(`bool`, *optional*, defaults to `False`):
150149
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
151150
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,9 +787,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
787787
cache_dir (`Union[str, os.PathLike]`, *optional*):
788788
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
789789
is not used.
790-
torch_dtype (`str` or `torch.dtype`, *optional*):
791-
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
792-
dtype is automatically derived from the model's weights.
790+
torch_dtype (`torch.dtype`, *optional*):
791+
Override the default `torch.dtype` and load the model with another dtype.
793792
force_download (`bool`, *optional*, defaults to `False`):
794793
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
795794
cached versions if they exist.

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
322322
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
323323
saved using
324324
[`~DiffusionPipeline.save_pretrained`].
325-
torch_dtype (`str` or `torch.dtype`, *optional*):
326-
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
327-
dtype is automatically derived from the model's weights.
325+
torch_dtype (`torch.dtype`, *optional*):
326+
Override the default `torch.dtype` and load the model with another dtype.
328327
force_download (`bool`, *optional*, defaults to `False`):
329328
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
330329
cached versions if they exist.
@@ -619,8 +618,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
619618
saved using
620619
[`~DiffusionPipeline.save_pretrained`].
621620
torch_dtype (`str` or `torch.dtype`, *optional*):
622-
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
623-
dtype is automatically derived from the model's weights.
621+
Override the default `torch.dtype` and load the model with another dtype.
624622
force_download (`bool`, *optional*, defaults to `False`):
625623
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
626624
cached versions if they exist.
@@ -930,8 +928,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
930928
saved using
931929
[`~DiffusionPipeline.save_pretrained`].
932930
torch_dtype (`str` or `torch.dtype`, *optional*):
933-
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
934-
dtype is automatically derived from the model's weights.
931+
Override the default `torch.dtype` and load the model with another dtype.
935932
force_download (`bool`, *optional*, defaults to `False`):
936933
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
937934
cached versions if they exist.

src/diffusers/pipelines/pipeline_flax_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
248248
pretrained pipeline hosted on the Hub.
249249
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
250250
using [`~FlaxDiffusionPipeline.save_pretrained`].
251-
dtype (`str` or `jnp.dtype`, *optional*):
252-
Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
253-
automatically derived from the model's weights.
251+
dtype (`jnp.dtype`, *optional*):
252+
Override the default `jnp.dtype` and load the model under this dtype.
254253
force_download (`bool`, *optional*, defaults to `False`):
255254
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
256255
cached versions if they exist.

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -573,12 +573,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
573573
saved using
574574
[`~DiffusionPipeline.save_pretrained`].
575575
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
576-
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
577-
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
578-
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
579-
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
580-
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
581-
torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
576+
torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
577+
Override the default `torch.dtype` and load the model with another dtype. To load submodels with
578+
different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
579+
Set the default dtype for unspecified components with `default` (for example `{'transformer':
580+
torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
581+
`torch.float32` is used.
582582
custom_pipeline (`str`, *optional*):
583583
584584
<Tip warning={true}>

0 commit comments

Comments
 (0)