
Commit 81bb48a

handle dtype casting.
1 parent: de6394a

File tree

1 file changed (+5, -1 lines)

src/diffusers/models/modeling_utils.py

Lines changed: 5 additions & 1 deletion
@@ -963,7 +963,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
             )
-        elif torch_dtype is not None and hf_quantizer is None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`. `transformers` does
+        # a global dtype setting (see: https://github.com/huggingface/transformers/blob/fa3f2db5c7405a742fcb8f686d3754f70db00977/src/transformers/modeling_utils.py#L4021),
+        # but this would prevent us from doing things like https://github.com/huggingface/diffusers/pull/9177/.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
         if hf_quantizer is not None:
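
For context on the change above, here is a minimal, self-contained sketch of why the global `model.to(torch_dtype)` is skipped when some modules must stay in float32. This is not the actual diffusers code path; the `_keep_in_fp32_modules` attribute and the helper below are illustrative assumptions.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical analogue of a class-level "keep these modules in fp32" list.
    _keep_in_fp32_modules = ["norm"]

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

def cast_except_keep_in_fp32(model, torch_dtype):
    # Cast every submodule to `torch_dtype`, except those whose qualified name
    # matches an entry in the keep-in-fp32 list; those are forced to float32.
    keep = getattr(model, "_keep_in_fp32_modules", [])
    for name, module in model.named_modules():
        if any(key in name for key in keep):
            module.to(torch.float32)
        elif name:  # skip the root module so we never do a global cast
            module.to(torch_dtype)
    return model

model = cast_except_keep_in_fp32(TinyModel(), torch.float16)
print(model.proj.weight.dtype)  # torch.float16
print(model.norm.weight.dtype)  # torch.float32 -- preserved

# A later global cast would silently undo the fp32 override,
# which is what the new `not use_keep_in_fp32_modules` guard avoids:
model.to(torch.float16)
print(model.norm.weight.dtype)  # torch.float16 -- override lost

In short, skipping the global cast when `use_keep_in_fp32_modules` is set lets per-module dtype decisions made during loading survive, which is exactly what the added condition in the diff accomplishes.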

0 commit comments
