Commit 4acbf5c

a-r-r-o-w and stevhliu authored
Update docs/source/en/using-diffusers/schedulers.md
Co-authored-by: Steven Liu <[email protected]>
1 parent 8fbd332 commit 4acbf5c

1 file changed: +11 additions, −2 deletions

docs/source/en/using-diffusers/schedulers.md

Lines changed: 11 additions & 2 deletions
````diff
@@ -243,5 +243,14 @@ unet = UNet2DConditionModel.from_pretrained(
 unet.save_pretrained("./local-unet", variant="non_ema")
 ```
 
-> [!TIP]
-> The `dtype` in which a model is loaded can be specified with the `torch_dtype` argument when loading models with `from_pretrained`. On-the-fly dtype conversion can also be done using the Pytorch provided `.to()` method. An important distinction to keep in mind is that the latter converts all weights to the specified dtype, while the former takes into account a special model attribute (`_keep_in_fp32_modules`) when loading the weights. This is important in cases where some layers in the model must remain in FP32 precision for numerical stability and best generation quality. An example can be found [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374).
+Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
+
+```py
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights, unlike the `torch_dtype` argument, which respects `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
````
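To make the distinction concrete, here is a minimal, self-contained sketch using plain `torch.nn` modules rather than an actual diffusers model (the `LayerNorm` here merely stands in for a module a model might list in `_keep_in_fp32_modules`). It shows that `.to()` downcasts every parameter with no exceptions:

```python
import torch

# Toy model standing in for a diffusers model; the LayerNorm plays the role
# of a layer you might want to keep in fp32 for numerical stability.
model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.LayerNorm(4),
)

# .to() converts *all* weights, including the norm layer.
model = model.to(torch.float16)
print(model[0].weight.dtype)  # torch.float16
print(model[1].weight.dtype)  # torch.float16 -- downcast too
```

By contrast, loading with `from_pretrained(..., torch_dtype=torch.float16)` leaves any module named in the model's `_keep_in_fp32_modules` attribute in `torch.float32`.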
