Skip to content

Commit e745148

Browse files
committed
The SD2.x projection dim hack is no longer needed (resolved upstream; see huggingface/diffusers#10770).
1 parent 67f80fd commit e745148

File tree

2 files changed

+33
-117
lines changed

2 files changed

+33
-117
lines changed

dgenerate/pipelinewrapper/pipelines.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,20 +2517,19 @@ def _handle_generic_pipeline_load_failure(e):
25172517
raise UnsupportedPipelineConfigError(
25182518
'Single file model loads do not support the subfolder option.')
25192519
try:
2520-
with _util._patch_sd21_clip_from_ldm():
2521-
pipeline = _pipeline_creation_args_debug(
2522-
backend='Torch',
2523-
cls=pipeline_class,
2524-
method=pipeline_class.from_single_file,
2525-
original_config=original_config,
2526-
model=model_path,
2527-
token=auth_token,
2528-
revision=revision,
2529-
variant=variant,
2530-
torch_dtype=torch_dtype,
2531-
use_safe_tensors=model_path.endswith('.safetensors'),
2532-
local_files_only=local_files_only,
2533-
**creation_kwargs)
2520+
pipeline = _pipeline_creation_args_debug(
2521+
backend='Torch',
2522+
cls=pipeline_class,
2523+
method=pipeline_class.from_single_file,
2524+
original_config=original_config,
2525+
model=model_path,
2526+
token=auth_token,
2527+
revision=revision,
2528+
variant=variant,
2529+
torch_dtype=torch_dtype,
2530+
use_safe_tensors=model_path.endswith('.safetensors'),
2531+
local_files_only=local_files_only,
2532+
**creation_kwargs)
25342533

25352534
except diffusers.loaders.single_file.SingleFileComponentError as e:
25362535
msg = str(e)

dgenerate/pipelinewrapper/util.py

Lines changed: 20 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -212,28 +212,27 @@ def single_file_load_sub_module(
212212
else:
213213
cached_model_config_path = default_pretrained_model_config_name
214214

215-
with _patch_sd21_clip_from_ldm():
216-
args = {
217-
"cached_model_config_path": cached_model_config_path,
218-
"library_name": library_name,
219-
"class_name": class_name,
220-
"name": name,
221-
"torch_dtype": dtype,
222-
"original_config": original_config,
223-
"local_files_only": local_files_only
224-
}
225-
226-
_messages.debug_log(
227-
f'Loading a "{class_name}" with: '
228-
f'diffusers.loaders.load_single_file_sub_model({args})'
229-
)
215+
args = {
216+
"cached_model_config_path": cached_model_config_path,
217+
"library_name": library_name,
218+
"class_name": class_name,
219+
"name": name,
220+
"torch_dtype": dtype,
221+
"original_config": original_config,
222+
"local_files_only": local_files_only
223+
}
230224

231-
model = _single_file.load_single_file_sub_model(
232-
**args,
233-
checkpoint=checkpoint,
234-
pipelines=diffusers.pipelines,
235-
is_pipeline_module=False
236-
)
225+
_messages.debug_log(
226+
f'Loading a "{class_name}" with: '
227+
f'diffusers.loaders.load_single_file_sub_model({args})'
228+
)
229+
230+
model = _single_file.load_single_file_sub_model(
231+
**args,
232+
checkpoint=checkpoint,
233+
pipelines=diffusers.pipelines,
234+
is_pipeline_module=False
235+
)
237236

238237
return model
239238

@@ -870,86 +869,4 @@ def is_single_file_model_load(path):
870869
return False
871870

872871

873-
@contextlib.contextmanager
874-
def _patch_sd21_clip_from_ldm():
875-
"""
876-
A context manager which temporarily patches diffusers clip / text_encoder model loading
877-
for single file checkpoints, this fixes loading SD2.1 CivitAI checkpoints
878-
with StableDiffusionPipeline.from_single_file, and also loading a text encoder
879-
individually via checkpoint extraction using dgenerate's TextEncoderUri class.
880-
"""
881-
og_func = diffusers.loaders.single_file_utils.convert_open_clip_checkpoint
882-
diffusers.loaders.single_file_utils.convert_open_clip_checkpoint = _convert_open_clip_checkpoint
883-
try:
884-
yield
885-
finally:
886-
diffusers.loaders.single_file_utils.convert_open_clip_checkpoint = og_func
887-
888-
889-
def _convert_open_clip_checkpoint(
890-
text_model,
891-
checkpoint,
892-
prefix="cond_stage_model.model.",
893-
):
894-
text_model_dict = {}
895-
text_proj_key = prefix + "text_projection"
896-
897-
if text_proj_key in checkpoint:
898-
text_proj_dim = int(checkpoint[text_proj_key].shape[1])
899-
elif hasattr(text_model.config, "hidden_size"):
900-
text_proj_dim = text_model.config.hidden_size
901-
else:
902-
text_proj_dim = _single_file_utils.LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
903-
904-
keys = list(checkpoint.keys())
905-
keys_to_ignore = _single_file_utils.SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
906-
907-
openclip_diffusers_ldm_map = _single_file_utils.DIFFUSERS_TO_LDM_MAPPING["openclip"]["layers"]
908-
for diffusers_key, ldm_key in openclip_diffusers_ldm_map.items():
909-
ldm_key = prefix + ldm_key
910-
if ldm_key not in checkpoint:
911-
continue
912-
if ldm_key in keys_to_ignore:
913-
continue
914-
if ldm_key.endswith("text_projection"):
915-
text_model_dict[diffusers_key] = checkpoint[ldm_key].T.contiguous()
916-
else:
917-
text_model_dict[diffusers_key] = checkpoint[ldm_key]
918-
919-
for key in keys:
920-
if key in keys_to_ignore:
921-
continue
922-
923-
if not key.startswith(prefix + "transformer."):
924-
continue
925-
926-
diffusers_key = key.replace(prefix + "transformer.", "")
927-
transformer_diffusers_to_ldm_map = _single_file_utils.DIFFUSERS_TO_LDM_MAPPING["openclip"]["transformer"]
928-
for new_key, old_key in transformer_diffusers_to_ldm_map.items():
929-
diffusers_key = (
930-
diffusers_key.replace(old_key, new_key).replace(".in_proj_weight", "").replace(".in_proj_bias", "")
931-
)
932-
933-
if key.endswith(".in_proj_weight"):
934-
weight_value = checkpoint.get(key)
935-
936-
text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
937-
text_model_dict[diffusers_key + ".k_proj.weight"] = (
938-
weight_value[text_proj_dim: text_proj_dim * 2, :].clone().detach()
939-
)
940-
text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2:, :].clone().detach()
941-
942-
elif key.endswith(".in_proj_bias"):
943-
weight_value = checkpoint.get(key)
944-
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
945-
text_model_dict[diffusers_key + ".k_proj.bias"] = (
946-
weight_value[text_proj_dim: text_proj_dim * 2].clone().detach()
947-
)
948-
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2:].clone().detach()
949-
else:
950-
text_model_dict[diffusers_key] = checkpoint.get(key)
951-
952-
return text_model_dict
953-
954-
955872
__all__ = _types.module_all()

0 commit comments

Comments
 (0)