Commits (37)
18de3ad
run control-lora on diffusers
lavinal712 Jan 30, 2025
e9d91e1
cannot load lora adapter
lavinal712 Feb 1, 2025
9cf8ad7
test
lavinal712 Feb 4, 2025
2453e14
1
lavinal712 Feb 7, 2025
39b3b84
add control-lora
lavinal712 Feb 7, 2025
de61226
1
lavinal712 Feb 15, 2025
10daac7
1
lavinal712 Feb 15, 2025
523967f
1
lavinal712 Feb 15, 2025
dd24464
Merge branch 'huggingface:main' into control-lora
lavinal712 Feb 23, 2025
33288e6
Merge branch 'huggingface:main' into control-lora
lavinal712 Mar 17, 2025
280cf7f
Merge branch 'huggingface:main' into control-lora
lavinal712 Mar 23, 2025
7c25a06
fix PeftAdapterMixin
lavinal712 Mar 23, 2025
0719c20
fix module_to_save bug
lavinal712 Mar 23, 2025
81eed41
delete json print
lavinal712 Mar 23, 2025
2de1505
Merge branch 'main' into control-lora
sayakpaul Mar 25, 2025
ce2b34b
Merge branch 'main' into control-lora
lavinal712 Mar 26, 2025
6a1ff82
resolve conflits
lavinal712 Apr 9, 2025
ab9eeff
Merge branch 'main' into control-lora
lavinal712 Apr 9, 2025
6fff794
merged but bug
lavinal712 Apr 9, 2025
8f7fc0a
Merge branch 'huggingface:main' into control-lora
lavinal712 May 29, 2025
63bafc8
change peft.py
lavinal712 May 29, 2025
c134bca
change peft.py
lavinal712 May 29, 2025
39e9254
Merge branch 'huggingface:main' into control-lora
lavinal712 Jul 2, 2025
d752992
Merge branch 'huggingface:main' into control-lora
lavinal712 Jul 5, 2025
0a5bd74
1
lavinal712 Jul 5, 2025
53a06cc
delete state_dict print
lavinal712 Jul 5, 2025
23cba18
fix alpha
lavinal712 Jul 5, 2025
d3a0755
Merge branch 'main' into control-lora
lavinal712 Jul 21, 2025
af8255e
Merge branch 'main' into control-lora
lavinal712 Jul 30, 2025
c6c13b6
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 8, 2025
4a64d64
Merge branch 'main' into control-lora
lavinal712 Aug 14, 2025
59a42b2
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 17, 2025
1c90272
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 18, 2025
a2eff1c
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 20, 2025
00a26cd
Create control_lora.py
lavinal712 Aug 20, 2025
1e8221c
Add files via upload
lavinal712 Aug 20, 2025
9d94c37
Merge branch 'huggingface:main' into control-lora
lavinal712 Sep 22, 2025
src/diffusers/loaders/single_file_utils.py (144 additions, 0 deletions)
@@ -1311,6 +1311,150 @@ def convert_controlnet_checkpoint(
    return new_checkpoint


Review thread on this function:
Member: I think we can remove these?
Contributor Author: Sure

def convert_control_lora_checkpoint(
    checkpoint,
    config,
    **kwargs,
):
    # Return checkpoint if it's already been converted
    if "time_embedding.linear_1.weight" in checkpoint:
        return checkpoint
    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
    if "time_embed.0.weight" in checkpoint:
        controlnet_state_dict = checkpoint

    else:
        controlnet_state_dict = {}
        keys = list(checkpoint.keys())
        controlnet_key = LDM_CONTROLNET_KEY
        for key in keys:
            if key.startswith(controlnet_key):
                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)
            else:
                controlnet_state_dict[key] = checkpoint.get(key)

    new_checkpoint = {}
    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
    for diffusers_key, ldm_key in ldm_controlnet_keys.items():
        if ldm_key not in controlnet_state_dict:
            continue
        new_checkpoint[diffusers_key] = controlnet_state_dict[ldm_key]
    for k, v in controlnet_state_dict.items():
        if "time_embed.0" in k:
            new_checkpoint[k.replace("time_embed.0", "time_embedding.linear_1")] = v
        elif "time_embed.2" in k:
            new_checkpoint[k.replace("time_embed.2", "time_embedding.linear_2")] = v
        elif "input_blocks.0.0" in k:
            new_checkpoint[k.replace("input_blocks.0.0", "conv_in")] = v
        elif "label_emb.0.0" in k:
            new_checkpoint[k.replace("label_emb.0.0", "add_embedding.linear_1")] = v
        elif "label_emb.0.2" in k:
            new_checkpoint[k.replace("label_emb.0.2", "add_embedding.linear_2")] = v
        elif "input_blocks.3.0.op" in k:
            new_checkpoint[k.replace("input_blocks.3.0.op", "down_blocks.0.downsamplers.0.conv")] = v
        elif "input_blocks.6.0.op" in k:
            new_checkpoint[k.replace("input_blocks.6.0.op", "down_blocks.1.downsamplers.0.conv")] = v

    # Retrieves the keys for the input blocks only
    num_input_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "input_blocks" in layer}
    )
    input_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Down blocks
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        update_unet_resnet_ldm_to_diffusers(
            resnets,
            new_checkpoint,
            controlnet_state_dict,
            {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"},
        )

        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
        if attentions:
            update_unet_attention_ldm_to_diffusers(
                attentions,
                new_checkpoint,
                controlnet_state_dict,
                {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"},
            )

    # controlnet down blocks
    for i in range(num_input_blocks):
        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len(
        {".".join(layer.split(".")[:2]) for layer in controlnet_state_dict if "middle_block" in layer}
    )
    middle_blocks = {
        layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Mid blocks
    for key in middle_blocks.keys():
        diffusers_key = max(key - 1, 0)
        if key % 2 == 0:
            update_unet_resnet_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
            )
        else:
            update_unet_attention_ldm_to_diffusers(
                middle_blocks[key],
                new_checkpoint,
                controlnet_state_dict,
                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
            )

    # mid block
    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")

    # controlnet cond embedding blocks
    cond_embedding_blocks = {
        ".".join(layer.split(".")[:2])
        for layer in controlnet_state_dict
        if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
    }
    num_cond_embedding_blocks = len(cond_embedding_blocks)

    for idx in range(1, num_cond_embedding_blocks + 1):
        diffusers_idx = idx - 1
        cond_block_id = 2 * idx

        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.weight"
        )
        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.bias"
        )

    return new_checkpoint
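
For readers following the diff, a minimal usage sketch of the new helper is shown below. The checkpoint filename and the config dict are illustrative assumptions (the converter only reads config["layers_per_block"]), and the import path assumes a diffusers install that includes this branch.

# Minimal sketch, not part of the PR: the file name and config values are assumptions.
from safetensors.torch import load_file

from diffusers.loaders.single_file_utils import convert_control_lora_checkpoint

# A Control-LoRA state dict in LDM layout (hypothetical local file).
state_dict = load_file("control-lora-canny-rank128.safetensors")

# Only "layers_per_block" is consulted when grouping the down blocks; 2 is the usual SDXL value.
config = {"layers_per_block": 2}

converted = convert_control_lora_checkpoint(state_dict, config)
print(f"{len(converted)} keys remapped to the diffusers layout")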


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys