|
62 | 62 | "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias", |
63 | 63 | "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", |
64 | 64 | "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias", |
65 | | - "controlnet": "control_model.time_embed.0.weight", |
| 65 | + "controlnet": [ |
| 66 | + "control_model.time_embed.0.weight", |
| 67 | + "controlnet_cond_embedding.conv_in.weight", |
| 68 | + ], |
| 69 | + # TODO: find non-Diffusers keys for controlnet_xl |
| 70 | + "controlnet_xl": "add_embedding.linear_1.weight", |
| 71 | + "controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight", |
| 72 | + "controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight", |
66 | 73 | "playground-v2-5": "edm_mean", |
67 | 74 | "inpainting": "model.diffusion_model.input_blocks.0.0.weight", |
68 | 75 | "clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", |
|
96 | 103 | "inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"}, |
97 | 104 | "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"}, |
98 | 105 | "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"}, |
| 106 | + "controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"}, |
| 107 | + "controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"}, |
| 108 | + "controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"}, |
99 | 109 | "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"}, |
100 | 110 | "v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"}, |
101 | 111 | "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"}, |
@@ -481,8 +491,16 @@ def infer_diffusers_model_type(checkpoint): |
481 | 491 | elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint: |
482 | 492 | model_type = "upscale" |
483 | 493 |
|
484 | | - elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint: |
485 | | - model_type = "controlnet" |
| 494 | + elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]): |
| 495 | + if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint: |
| 496 | + if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint: |
| 497 | + model_type = "controlnet_xl_large" |
| 498 | + elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint: |
| 499 | + model_type = "controlnet_xl_mid" |
| 500 | + else: |
| 501 | + model_type = "controlnet_xl_small" |
| 502 | + else: |
| 503 | + model_type = "controlnet" |
486 | 504 |
|
487 | 505 | elif ( |
488 | 506 | CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint |
@@ -1072,6 +1090,9 @@ def convert_controlnet_checkpoint( |
1072 | 1090 | config, |
1073 | 1091 | **kwargs, |
1074 | 1092 | ): |
| 1093 | + # Return checkpoint if it's already been converted |
| 1094 | + if "time_embedding.linear_1.weight" in checkpoint: |
| 1095 | + return checkpoint |
1075 | 1096 | # Some controlnet ckpt files are distributed independently from the rest of the |
1076 | 1097 | # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ |
1077 | 1098 | if "time_embed.0.weight" in checkpoint: |
|
0 commit comments