Commit 12f5c59

float8 check.

Parent: 71316a6

File tree

1 file changed (+3, -4 lines)

src/diffusers/models/model_loading_utils.py

Lines changed: 3 additions & 4 deletions
@@ -185,16 +185,15 @@ def load_model_dict_into_meta(
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
     unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
-    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
 
     for param_name, param in state_dict.items():
         if param_name not in empty_state_dict:
             continue
 
-        # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
+        # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
-        is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
-        if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
+        # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
+        if torch.is_floating_point(param):
             if (
                 keep_in_fp32_modules is not None
                 and any(
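
For context (not part of the commit), a minimal sketch of what the removed guard was checking and why the TODO remains: on PyTorch builds that expose float8, torch.is_floating_point() returns True for float8_e4m3fn tensors, so after this change such params fall into the casting branch like any other floating-point param.

    # Hypothetical sketch, not from the commit: assumes a PyTorch build that
    # exposes torch.float8_e4m3fn.
    import torch

    if hasattr(torch, "float8_e4m3fn"):  # the availability check this commit removes
        param = torch.zeros(2, dtype=torch.float8_e4m3fn)
        # float8_e4m3fn counts as a floating-point dtype, so after this commit
        # load_model_dict_into_meta casts it like any other float param.
        print(torch.is_floating_point(param))  # True
        print(param.to(torch.float16).dtype)   # torch.float16
    else:
        print("this PyTorch build has no float8_e4m3fn")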
