Commit d6503f6

[single file] Detect CLIP max_position_embeddings
1 parent 97e0ef4 commit d6503f6

1 file changed: +8 -0 lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 8 additions & 0 deletions
@@ -1587,6 +1587,7 @@ def create_diffusers_clip_model_from_ldm(
         model = cls(model_config)
 
     position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]
+    position_embedding_size = model.text_model.embeddings.position_embedding.weight.shape[0]
 
     if is_clip_model(checkpoint):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
@@ -1628,6 +1629,13 @@ def create_diffusers_clip_model_from_ldm(
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
+    num_position_embeddings = diffusers_format_checkpoint["text_model.embeddings.position_embedding.weight"].shape[0]
+    if num_position_embeddings != position_embedding_size:
+        logger.warning(f"Overriding config with detected {num_position_embeddings=}.")
+        model_config.max_position_embeddings = num_position_embeddings
+        with ctx():
+            model = cls(model_config)
+
     if is_accelerate_available():
         load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
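
As a rough standalone sketch of what the added block does (illustrative only: the helper name build_clip_text_model and the plain logging setup below are made up, and the real create_diffusers_clip_model_from_ldm also builds the model under a meta-device context and supports several CLIP variants), the idea is to read the number of position embeddings straight from the converted checkpoint and, when it disagrees with the config, override max_position_embeddings before rebuilding the model:

import logging

from transformers import CLIPTextConfig, CLIPTextModel

logger = logging.getLogger(__name__)


def build_clip_text_model(state_dict, model_config: CLIPTextConfig) -> CLIPTextModel:
    # Build once from the given config to see how many position embeddings it expects.
    model = CLIPTextModel(model_config)
    position_embedding_size = model.text_model.embeddings.position_embedding.weight.shape[0]

    # Treat the checkpoint as the source of truth for the number of positions it was trained with.
    num_position_embeddings = state_dict["text_model.embeddings.position_embedding.weight"].shape[0]
    if num_position_embeddings != position_embedding_size:
        # e.g. a long-context CLIP text encoder shipping more than the default 77 positions
        logger.warning(f"Overriding config with detected {num_position_embeddings=}.")
        model_config.max_position_embeddings = num_position_embeddings
        model = CLIPTextModel(model_config)

    model.load_state_dict(state_dict, strict=False)
    return model

In the commit itself the rebuild happens inside the same ctx() used for the first construction, so (assuming ctx is the empty-weights context set up earlier in the function when accelerate is available) the resized model is still created without allocating real weights before load_model_dict_into_meta fills it from the checkpoint.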
