Skip to content

Commit f69511e

Browse files
DN6, yiyixu, and sayakpaul
authored
[Single File Loading] Handle unexpected keys in CLIP models when accelerate isn't installed. (#8462)
* update
* update
* update
* update
* update

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent d2b10b1 commit f69511e

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

src/diffusers/loaders/single_file_model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,16 +276,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
276276

277277
if is_accelerate_available():
278278
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
279-
if model._keys_to_ignore_on_load_unexpected is not None:
280-
for pat in model._keys_to_ignore_on_load_unexpected:
281-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
282279

283-
if len(unexpected_keys) > 0:
284-
logger.warning(
285-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
286-
)
287280
else:
288-
model.load_state_dict(diffusers_format_checkpoint)
281+
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
282+
283+
if model._keys_to_ignore_on_load_unexpected is not None:
284+
for pat in model._keys_to_ignore_on_load_unexpected:
285+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
286+
287+
if len(unexpected_keys) > 0:
288+
logger.warning(
289+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
290+
)
289291

290292
if torch_dtype is not None:
291293
model.to(torch_dtype)

src/diffusers/loaders/single_file_utils.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,8 +1268,6 @@ def convert_open_clip_checkpoint(
12681268
else:
12691269
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
12701270

1271-
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
1272-
12731271
keys = list(checkpoint.keys())
12741272
keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
12751273

@@ -1318,9 +1316,6 @@ def convert_open_clip_checkpoint(
13181316
else:
13191317
text_model_dict[diffusers_key] = checkpoint.get(key)
13201318

1321-
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1322-
text_model_dict.pop("text_model.embeddings.position_ids", None)
1323-
13241319
return text_model_dict
13251320

13261321

@@ -1414,17 +1409,17 @@ def create_diffusers_clip_model_from_ldm(
14141409

14151410
if is_accelerate_available():
14161411
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1417-
if model._keys_to_ignore_on_load_unexpected is not None:
1418-
for pat in model._keys_to_ignore_on_load_unexpected:
1419-
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1412+
else:
1413+
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
14201414

1421-
if len(unexpected_keys) > 0:
1422-
logger.warning(
1423-
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1424-
)
1415+
if model._keys_to_ignore_on_load_unexpected is not None:
1416+
for pat in model._keys_to_ignore_on_load_unexpected:
1417+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
14251418

1426-
else:
1427-
model.load_state_dict(diffusers_format_checkpoint)
1419+
if len(unexpected_keys) > 0:
1420+
logger.warning(
1421+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1422+
)
14281423

14291424
if torch_dtype is not None:
14301425
model.to(torch_dtype)

0 commit comments

Comments
 (0)