From d6503f698f1c4c596ee53d8159189ff58959b0b6 Mon Sep 17 00:00:00 2001 From: hlky Date: Sun, 13 Apr 2025 09:43:11 +0100 Subject: [PATCH] [single file] Detect CLIP `max_position_embeddings` --- src/diffusers/loaders/single_file_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 556f03f7992f..e3c3d5e56ae0 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -1587,6 +1587,7 @@ def create_diffusers_clip_model_from_ldm( model = cls(model_config) position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1] + position_embedding_size = model.text_model.embeddings.position_embedding.weight.shape[0] if is_clip_model(checkpoint): diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint) @@ -1628,6 +1629,13 @@ def create_diffusers_clip_model_from_ldm( else: raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.") + num_position_embeddings = diffusers_format_checkpoint["text_model.embeddings.position_embedding.weight"].shape[0] + if num_position_embeddings != position_embedding_size: + logger.warning(f"Overriding config with detected {num_position_embeddings=}.") + model_config.max_position_embeddings = num_position_embeddings + with ctx(): + model = cls(model_config) + if is_accelerate_available(): load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype) else: