Skip to content

Commit 0df7010

Browse files
committed
more fix from suggestions
1 parent aedf6af commit 0df7010

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,13 +1089,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
10891089
state_dict = load_state_dict(
10901090
resolved_archive_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries
10911091
)
1092+
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
1093+
model._fix_state_dict_keys_on_load(state_dict)
10921094

10931095
if is_sharded:
10941096
loaded_keys = sharded_metadata["all_checkpoint_keys"]
10951097
else:
10961098
loaded_keys = list(state_dict.keys())
1097-
# TODO: hacky solution
1098-
loaded_keys = list(model._fix_state_dict_keys_on_load({key: "" for key in loaded_keys}))
10991099

11001100
if hf_quantizer is not None:
11011101
hf_quantizer.preprocess_model(
@@ -1305,7 +1305,6 @@ def _load_pretrained_model(
13051305

13061306
for shard_file in resolved_archive_file:
13071307
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
1308-
model._fix_state_dict_keys_on_load(state_dict)
13091308

13101309
def _find_mismatched_keys(
13111310
state_dict,
@@ -1578,7 +1577,8 @@ def _fix_state_dict_keys_on_load(self, state_dict: OrderedDict) -> None:
15781577
"""
15791578
This function fix the state dict of the model to take into account some changes that were made in the model
15801579
architecture:
1581-
- depretated attention blocks
1580+
- deprecated attention blocks (happened before we introduced sharded checkpoint,
1581+
so this is why we apply this method only when loading non sharded checkpoints for now)
15821582
"""
15831583
deprecated_attention_block_paths = []
15841584

0 commit comments

Comments
 (0)