You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
> The `dtype` in which a model is loaded can be specified with the `torch_dtype` argument when loading models with `from_pretrained`. On-the-fly dtype conversion can also be done using the Pytorch provided `.to()` method. An important distinction to keep in mind is that the latter converts all weights to the specified dtype, while the former takes into account a special model attribute (`_keep_in_fp32_modules`) when loading the weights. This is important in cases where some layers in the model must remain in FP32 precision for numerical stability and best generation quality. An example can be found [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374).
246
+
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
0 commit comments