Commits (37)
18de3ad  run control-lora on diffusers (lavinal712, Jan 30, 2025)
e9d91e1  cannot load lora adapter (lavinal712, Feb 1, 2025)
9cf8ad7  test (lavinal712, Feb 4, 2025)
2453e14  1 (lavinal712, Feb 7, 2025)
39b3b84  add control-lora (lavinal712, Feb 7, 2025)
de61226  1 (lavinal712, Feb 15, 2025)
10daac7  1 (lavinal712, Feb 15, 2025)
523967f  1 (lavinal712, Feb 15, 2025)
dd24464  Merge branch 'huggingface:main' into control-lora (lavinal712, Feb 23, 2025)
33288e6  Merge branch 'huggingface:main' into control-lora (lavinal712, Mar 17, 2025)
280cf7f  Merge branch 'huggingface:main' into control-lora (lavinal712, Mar 23, 2025)
7c25a06  fix PeftAdapterMixin (lavinal712, Mar 23, 2025)
0719c20  fix module_to_save bug (lavinal712, Mar 23, 2025)
81eed41  delete json print (lavinal712, Mar 23, 2025)
2de1505  Merge branch 'main' into control-lora (sayakpaul, Mar 25, 2025)
ce2b34b  Merge branch 'main' into control-lora (lavinal712, Mar 26, 2025)
6a1ff82  resolve conflits (lavinal712, Apr 9, 2025)
ab9eeff  Merge branch 'main' into control-lora (lavinal712, Apr 9, 2025)
6fff794  merged but bug (lavinal712, Apr 9, 2025)
8f7fc0a  Merge branch 'huggingface:main' into control-lora (lavinal712, May 29, 2025)
63bafc8  change peft.py (lavinal712, May 29, 2025)
c134bca  change peft.py (lavinal712, May 29, 2025)
39e9254  Merge branch 'huggingface:main' into control-lora (lavinal712, Jul 2, 2025)
d752992  Merge branch 'huggingface:main' into control-lora (lavinal712, Jul 5, 2025)
0a5bd74  1 (lavinal712, Jul 5, 2025)
53a06cc  delete state_dict print (lavinal712, Jul 5, 2025)
23cba18  fix alpha (lavinal712, Jul 5, 2025)
d3a0755  Merge branch 'main' into control-lora (lavinal712, Jul 21, 2025)
af8255e  Merge branch 'main' into control-lora (lavinal712, Jul 30, 2025)
c6c13b6  Merge branch 'huggingface:main' into control-lora (lavinal712, Aug 8, 2025)
4a64d64  Merge branch 'main' into control-lora (lavinal712, Aug 14, 2025)
59a42b2  Merge branch 'huggingface:main' into control-lora (lavinal712, Aug 17, 2025)
1c90272  Merge branch 'huggingface:main' into control-lora (lavinal712, Aug 18, 2025)
a2eff1c  Merge branch 'huggingface:main' into control-lora (lavinal712, Aug 20, 2025)
00a26cd  Create control_lora.py (lavinal712, Aug 20, 2025)
1e8221c  Add files via upload (lavinal712, Aug 20, 2025)
9d94c37  Merge branch 'huggingface:main' into control-lora (lavinal712, Sep 22, 2025)
37 changes: 37 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -25,6 +25,7 @@
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -81,6 +82,33 @@ def _maybe_raise_error_for_ambiguity(config):
)


def _maybe_adjust_config_for_control_lora(config):
    """
    Move parameters that Control LoRA stores as full tensors out of `target_modules` and into `modules_to_save`.
    Entries ending in `.weight` name layers that ship complete replacement weights rather than low-rank factors;
    a matching `.bias` follows its weight, while a lone `.bias` keeps its layer in `target_modules`.
    """

    target_modules_before = config["target_modules"]
    target_modules = []
    modules_to_save = []

    for module in target_modules_before:
        if module.endswith("weight"):
            base_name = ".".join(module.split(".")[:-1])
            modules_to_save.append(base_name)
        elif module.endswith("bias"):
            base_name = ".".join(module.split(".")[:-1])
            if ".".join([base_name, "weight"]) in target_modules_before:
                modules_to_save.append(base_name)
            else:
                target_modules.append(base_name)
        else:
            target_modules.append(module)

    config["target_modules"] = list(set(target_modules))
    config["modules_to_save"] = list(set(modules_to_save))

    return config


class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -244,6 +272,13 @@ def load_lora_adapter(
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)

# Control LoRA from Stability AI (SAI) uses a different format than the Black Forest Labs (BFL) Control LoRA
# https://huggingface.co/stabilityai/control-lora/
is_control_lora = "lora_controlnet" in state_dict
if is_control_lora:
del state_dict["lora_controlnet"]
state_dict = convert_control_lora_state_dict_to_peft(state_dict)

# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
@@ -268,6 +303,8 @@

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if is_control_lora:
lora_config_kwargs = _maybe_adjust_config_for_control_lora(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
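The loader path above detects a Stability AI Control LoRA by its `lora_controlnet` sentinel key, converts the raw checkpoint with `convert_control_lora_state_dict_to_peft`, and then reshapes the PEFT config with `_maybe_adjust_config_for_control_lora`. The snippet below is a minimal, standalone sketch of that config adjustment on an illustrative input; the helper re-implements the logic shown above for demonstration (with sorted output for reproducibility), and the module names in the sample config are made up rather than taken from a real checkpoint.

```python
def adjust_config_for_control_lora(config: dict) -> dict:
    """Standalone sketch of the config adjustment above (illustration only)."""
    target_modules_before = config["target_modules"]
    target_modules, modules_to_save = [], []

    for module in target_modules_before:
        if module.endswith("weight"):
            # Full weight tensors are stored outright, so the whole layer becomes a module_to_save.
            modules_to_save.append(".".join(module.split(".")[:-1]))
        elif module.endswith("bias"):
            base_name = ".".join(module.split(".")[:-1])
            if ".".join([base_name, "weight"]) in target_modules_before:
                modules_to_save.append(base_name)
            else:
                target_modules.append(base_name)
        else:
            # Plain module names stay as LoRA targets.
            target_modules.append(module)

    config["target_modules"] = sorted(set(target_modules))
    config["modules_to_save"] = sorted(set(modules_to_save))
    return config


# Illustrative config: one attention projection targeted by LoRA plus a fully stored conv layer.
config = {
    "target_modules": [
        "mid_block.attentions.0.transformer_blocks.0.attn1.to_q",
        "conv_in.weight",
        "conv_in.bias",
    ]
}
print(adjust_config_for_control_lora(config))
# {'target_modules': ['mid_block.attentions.0.transformer_blocks.0.attn1.to_q'],
#  'modules_to_save': ['conv_in']}
```

The effect is that `conv_in` is wrapped by PEFT as a fully replaceable module, while `to_q` keeps its low-rank `lora_A`/`lora_B` pair.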
3 changes: 2 additions & 1 deletion src/diffusers/models/controlnets/controlnet.py
@@ -19,6 +19,7 @@
from torch.nn import functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
@@ -106,7 +107,7 @@ def forward(self, conditioning):
return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
"""
A ControlNet model.

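Making `ControlNetModel` inherit from `PeftAdapterMixin` is what exposes `load_lora_adapter` on ControlNets in the first place. The following is a hedged usage sketch assuming the branch from this PR; the exact `weight_name` inside `stabilityai/control-lora` and the `prefix=None` argument are assumptions to verify against the repository layout and the final API, not confirmed details.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, UNet2DConditionModel

# Build a ControlNet with the same topology as the SDXL base UNet, then attach
# the Stability AI Control LoRA checkpoint to it as a PEFT adapter.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_unet(unet)

# weight_name is an assumption; check the stabilityai/control-lora repo for the actual file path.
controlnet.load_lora_adapter(
    "stabilityai/control-lora",
    weight_name="control-LoRAs-rank256/control-lora-canny-rank256.safetensors",
    prefix=None,
)

pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
```

From there, generation works like any other SDXL ControlNet pipeline, with a Canny edge map passed as the conditioning image.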
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -128,6 +128,7 @@
from .remote_utils import remote_decode
from .state_dict_utils import (
convert_all_state_dict_to_peft,
convert_control_lora_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
179 changes: 179 additions & 0 deletions src/diffusers/utils/state_dict_utils.py
@@ -55,6 +55,36 @@ class StateDictType(enum.Enum):
".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
}

CONTROL_LORA_TO_DIFFUSERS = {
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
    ".to_k.down": ".to_k.lora_A.weight",
    ".to_k.up": ".to_k.lora_B.weight",
    ".to_v.down": ".to_v.lora_A.weight",
    ".to_v.up": ".to_v.lora_B.weight",
    ".to_out.0.down": ".to_out.0.lora_A.weight",
    ".to_out.0.up": ".to_out.0.lora_B.weight",
    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
    ".proj_in.down": ".proj_in.lora_A.weight",
    ".proj_in.up": ".proj_in.lora_B.weight",
    ".proj_out.down": ".proj_out.lora_A.weight",
    ".proj_out.up": ".proj_out.lora_B.weight",
    ".conv.down": ".conv.lora_A.weight",
    ".conv.up": ".conv.lora_B.weight",
    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
    "conv_in.down": "conv_in.lora_A.weight",
    "conv_in.up": "conv_in.lora_B.weight",
    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
}

DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
@@ -258,6 +288,155 @@ def convert_unet_state_dict_to_peft(state_dict):
return convert_state_dict(state_dict, mapping)


def convert_control_lora_state_dict_to_peft(state_dict):
    def _convert_controlnet_to_diffusers(state_dict):
        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
        logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")

        # Retrieves the keys for the input blocks only
        num_input_blocks = len(
            {".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer}
        )
        input_blocks = {
            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
            for layer_id in range(num_input_blocks)
        }
        layers_per_block = 2

        # op blocks
        op_blocks = [key for key in state_dict if "0.op" in key]

        converted_state_dict = {}
        # Conv in layers
        for key in input_blocks[0]:
            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet time embedding blocks
        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
        for key in time_embedding_blocks:
            diffusers_key = (
                key.replace("time_embed.0", "time_embedding.linear_1")
                .replace("time_embed.2", "time_embedding.linear_2")
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet label embedding blocks
        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
        for key in label_embedding_blocks:
            diffusers_key = (
                key.replace("label_emb.0.0", "add_embedding.linear_1")
                .replace("label_emb.0.2", "add_embedding.linear_2")
            )
            converted_state_dict[diffusers_key] = state_dict.get(key)

        # Down blocks
        for i in range(1, num_input_blocks):
            block_id = (i - 1) // (layers_per_block + 1)
            layer_in_block_id = (i - 1) % (layers_per_block + 1)

            resnets = [
                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
            ]
            for key in resnets:
                diffusers_key = (
                    key.replace("in_layers.0", "norm1")
                    .replace("in_layers.2", "conv1")
                    .replace("out_layers.0", "norm2")
                    .replace("out_layers.3", "conv2")
                    .replace("emb_layers.1", "time_emb_proj")
                    .replace("skip_connection", "conv_shortcut")
                )
                diffusers_key = diffusers_key.replace(
                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
                )
                converted_state_dict[diffusers_key] = state_dict.get(key)

            if f"input_blocks.{i}.0.op.bias" in state_dict:
                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
                    diffusers_key = key.replace(f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv")
                    converted_state_dict[diffusers_key] = state_dict.get(key)

            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
            if attentions:
                for key in attentions:
                    diffusers_key = key.replace(
                        f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
                    )
                    converted_state_dict[diffusers_key] = state_dict.get(key)

        # controlnet down blocks
        for i in range(num_input_blocks):
            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")

        # Retrieves the keys for the middle blocks only
        num_middle_blocks = len(
            {".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer}
        )
        middle_blocks = {
            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
            for layer_id in range(num_middle_blocks)
        }

        # Mid blocks
        for key in middle_blocks.keys():
            diffusers_key = max(key - 1, 0)
            if key % 2 == 0:
                for k in middle_blocks[key]:
                    diffusers_key_hf = (
                        k.replace("in_layers.0", "norm1")
                        .replace("in_layers.2", "conv1")
                        .replace("out_layers.0", "norm2")
                        .replace("out_layers.3", "conv2")
                        .replace("emb_layers.1", "time_emb_proj")
                        .replace("skip_connection", "conv_shortcut")
                    )
                    diffusers_key_hf = diffusers_key_hf.replace(
                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
            else:
                for k in middle_blocks[key]:
                    diffusers_key_hf = k.replace(
                        f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}"
                    )
                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)

        # mid block
        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")

        # controlnet cond embedding blocks
        cond_embedding_blocks = {
            ".".join(layer.split(".")[:2])
            for layer in state_dict
            if "input_hint_block" in layer and ("input_hint_block.0" not in layer) and ("input_hint_block.14" not in layer)
        }
        num_cond_embedding_blocks = len(cond_embedding_blocks)

        for idx in range(1, num_cond_embedding_blocks + 1):
            diffusers_idx = idx - 1
            cond_block_id = 2 * idx

            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.weight"
            )
            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
                f"input_hint_block.{cond_block_id}.bias"
            )

        for key in [key for key in state_dict if "input_hint_block.0" in key]:
            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        for key in [key for key in state_dict if "input_hint_block.14" in key]:
            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
            converted_state_dict[diffusers_key] = state_dict.get(key)

        return converted_state_dict

    state_dict = _convert_controlnet_to_diffusers(state_dict)
    mapping = CONTROL_LORA_TO_DIFFUSERS
    return convert_state_dict(state_dict, mapping)


def convert_all_state_dict_to_peft(state_dict):
r"""
Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
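Taken together, `convert_control_lora_state_dict_to_peft` works in two passes: the inner `_convert_controlnet_to_diffusers` renames raw Stability AI keys into the diffusers ControlNet layout (`input_blocks.* -> down_blocks.*`, `middle_block.* -> mid_block.*`, `input_hint_block.* -> controlnet_cond_embedding.*`), and the `CONTROL_LORA_TO_DIFFUSERS` mapping then rewrites the `.down`/`.up` LoRA factors into PEFT's `lora_A.weight`/`lora_B.weight` names. The sketch below illustrates only that second, suffix-rewriting step; the small helper mimics the pattern replacement rather than calling the library's `convert_state_dict`, and the sample keys and shapes are illustrative.

```python
import torch

# Illustrative subset of the CONTROL_LORA_TO_DIFFUSERS mapping defined above.
SUFFIX_MAP = {
    ".to_q.down": ".to_q.lora_A.weight",
    ".to_q.up": ".to_q.lora_B.weight",
}


def rename_control_lora_keys(state_dict, mapping):
    """Sketch: replace the first matching pattern in each key, leaving values untouched."""
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in mapping.items():
            if pattern in key:
                key = key.replace(pattern, replacement)
                break
        renamed[key] = value
    return renamed


# One LoRA factor pair for a single attention projection (rank-8 shapes chosen arbitrarily).
sample = {
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.down": torch.zeros(8, 640),
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.up": torch.zeros(640, 8),
}
for new_key in rename_control_lora_keys(sample, SUFFIX_MAP):
    print(new_key)
# down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight
# down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.lora_B.weight
```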