Skip to content

Commit 4e17f91

Browse files
remove redundant dtype check in flashpack
1 parent 73e5897 commit 4e17f91

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,32 +1223,29 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
12231223
torch.set_default_dtype(dtype_orig)
12241224

12251225
# flashpack requires a single dtype across all parameters
1226-
param_dtypes = {p.dtype for p in model.parameters()}
1227-
if len(param_dtypes) > 1:
1228-
pass
1229-
else:
1230-
try:
1231-
assign_from_file(model, flashpack_file)
1232-
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
12331226

1234-
if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
1235-
model = model.to(torch_dtype)
1227+
try:
1228+
assign_from_file(model, flashpack_file)
1229+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1230+
1231+
if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
1232+
model = model.to(torch_dtype)
12361233

1237-
model.eval()
1234+
model.eval()
12381235

1239-
if output_loading_info:
1240-
loading_info = {
1241-
"missing_keys": [],
1242-
"unexpected_keys": [],
1243-
"mismatched_keys": [],
1244-
"error_msgs": [],
1245-
}
1246-
return model, loading_info
1236+
if output_loading_info:
1237+
loading_info = {
1238+
"missing_keys": [],
1239+
"unexpected_keys": [],
1240+
"mismatched_keys": [],
1241+
"error_msgs": [],
1242+
}
1243+
return model, loading_info
12471244

1248-
return model
1245+
return model
12491246

1250-
except Exception:
1251-
pass
1247+
except Exception:
1248+
pass
12521249
# in the case it is sharded, we have already the index
12531250
if is_sharded:
12541251
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(

0 commit comments

Comments
 (0)