
Commit 2753abe

fp8 dtype

1 parent b48fedc · commit 2753abe

File tree

1 file changed: +10 −1

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 1 deletion
@@ -1143,8 +1143,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is float8, we don't use _set_default_torch_dtype and we downcast after loading the model
         dtype_orig = None
-        if torch_dtype is not None:
+        if torch_dtype is not None and not torch_dtype == getattr(torch, "float8_e4m3fn", None):
             if not isinstance(torch_dtype, torch.dtype):
                 raise ValueError(
                     f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
@@ -1231,6 +1232,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer.postprocess_model(model)
             model.hf_quantizer = hf_quantizer

+        if (
+            torch_dtype is not None
+            and torch_dtype == getattr(torch, "float8_e4m3fn", None)
+            and hf_quantizer is None
+            and not use_keep_in_fp32_modules
+        ):
+            model = model.to(torch_dtype)
+
         if hf_quantizer is not None:
             # We also make sure to purge `_pre_quantization_dtype` when we serialize
             # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
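Taken together, the two hunks let a caller pass an fp8 dtype straight to `from_pretrained`: the model is built and loaded in the usual dtype, then downcast in one shot at the end, provided no quantizer is attached and no modules must stay in fp32. A usage sketch under those assumptions (the repo id is illustrative, and fp8 here acts as a storage dtype, so you would typically upcast before running inference):

import torch
from diffusers import UNet2DConditionModel

# Illustrative repo id; any diffusers model repo with a "unet"
# subfolder works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "some-org/some-model",
    subfolder="unet",
    torch_dtype=torch.float8_e4m3fn,  # weights downcast to fp8 after loading
)

print(next(unet.parameters()).dtype)  # torch.float8_e4m3fn

# fp8 is a storage dtype: upcast (e.g. to fp16) before inference,
# since most kernels are not implemented for float8.
unet = unet.to(torch.float16)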
