Changes from 10 commits

Commits (37 total)
18de3ad
run control-lora on diffusers
lavinal712 Jan 30, 2025
e9d91e1
cannot load lora adapter
lavinal712 Feb 1, 2025
9cf8ad7
test
lavinal712 Feb 4, 2025
2453e14
1
lavinal712 Feb 7, 2025
39b3b84
add control-lora
lavinal712 Feb 7, 2025
de61226
1
lavinal712 Feb 15, 2025
10daac7
1
lavinal712 Feb 15, 2025
523967f
1
lavinal712 Feb 15, 2025
dd24464
Merge branch 'huggingface:main' into control-lora
lavinal712 Feb 23, 2025
33288e6
Merge branch 'huggingface:main' into control-lora
lavinal712 Mar 17, 2025
280cf7f
Merge branch 'huggingface:main' into control-lora
lavinal712 Mar 23, 2025
7c25a06
fix PeftAdapterMixin
lavinal712 Mar 23, 2025
0719c20
fix module_to_save bug
lavinal712 Mar 23, 2025
81eed41
delete json print
lavinal712 Mar 23, 2025
2de1505
Merge branch 'main' into control-lora
sayakpaul Mar 25, 2025
ce2b34b
Merge branch 'main' into control-lora
lavinal712 Mar 26, 2025
6a1ff82
resolve conflits
lavinal712 Apr 9, 2025
ab9eeff
Merge branch 'main' into control-lora
lavinal712 Apr 9, 2025
6fff794
merged but bug
lavinal712 Apr 9, 2025
8f7fc0a
Merge branch 'huggingface:main' into control-lora
lavinal712 May 29, 2025
63bafc8
change peft.py
lavinal712 May 29, 2025
c134bca
change peft.py
lavinal712 May 29, 2025
39e9254
Merge branch 'huggingface:main' into control-lora
lavinal712 Jul 2, 2025
d752992
Merge branch 'huggingface:main' into control-lora
lavinal712 Jul 5, 2025
0a5bd74
1
lavinal712 Jul 5, 2025
53a06cc
delete state_dict print
lavinal712 Jul 5, 2025
23cba18
fix alpha
lavinal712 Jul 5, 2025
d3a0755
Merge branch 'main' into control-lora
lavinal712 Jul 21, 2025
af8255e
Merge branch 'main' into control-lora
lavinal712 Jul 30, 2025
c6c13b6
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 8, 2025
4a64d64
Merge branch 'main' into control-lora
lavinal712 Aug 14, 2025
59a42b2
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 17, 2025
1c90272
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 18, 2025
a2eff1c
Merge branch 'huggingface:main' into control-lora
lavinal712 Aug 20, 2025
00a26cd
Create control_lora.py
lavinal712 Aug 20, 2025
1e8221c
Add files via upload
lavinal712 Aug 20, 2025
9d94c37
Merge branch 'huggingface:main' into control-lora
lavinal712 Sep 22, 2025
5 changes: 3 additions & 2 deletions src/diffusers/loaders/__init__.py
@@ -84,7 +84,7 @@ def text_encoder_attn_modules(text_encoder):
"SD3IPAdapterMixin",
]

_import_structure["peft"] = ["PeftAdapterMixin"]
_import_structure["peft"] = ["PeftAdapterMixin", "ControlLoRAMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -94,6 +94,7 @@ def text_encoder_attn_modules(text_encoder):
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
from .peft import ControlLoRAMixin

if is_transformers_available():
from .ip_adapter import (
@@ -120,7 +121,7 @@ def text_encoder_attn_modules(text_encoder):
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin

from .peft import PeftAdapterMixin
from .peft import PeftAdapterMixin, ControlLoRAMixin
else:
import sys

182 changes: 182 additions & 0 deletions src/diffusers/loaders/peft.py
@@ -25,6 +25,7 @@
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_control_lora_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
@@ -766,3 +767,184 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)


class ControlLoRAMixin(PeftAdapterMixin):
TARGET_MODULES = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", "proj_in", "proj_out",
"conv", "conv1", "conv2", "conv_in", "conv_shortcut", "linear_1", "linear_2", "time_emb_proj"]
SAVE_MODULES = ["controlnet_cond_embedding.conv_in", "controlnet_cond_embedding.blocks.0",
"controlnet_cond_embedding.blocks.1", "controlnet_cond_embedding.blocks.2",
"controlnet_cond_embedding.blocks.3", "controlnet_cond_embedding.blocks.4",
"controlnet_cond_embedding.blocks.5", "controlnet_cond_embedding.conv_out",
"controlnet_down_blocks.0", "controlnet_down_blocks.1", "controlnet_down_blocks.2",
"controlnet_down_blocks.3", "controlnet_down_blocks.4", "controlnet_down_blocks.5",
"controlnet_down_blocks.6", "controlnet_down_blocks.7", "controlnet_down_blocks.8",
"controlnet_mid_block", "norm", "norm1", "norm2", "norm3"]

Member: Why do we need these?

Contributor Author: During automated LoRA loading, the modules to be adapted have to be specified explicitly, and the original code provides no way to do that.

Member: Can you show me where these are required?

Contributor Author: weight_name.txt
This is the original Control-LoRA weight file. Comparing its format with the Diffusers format shows that certain modules are fine-tuned with LoRA while other modules are trained in full. The same implementation can be found at https://github.com/lavinal712/control-lora-v3/blob/main/train_control_lora_sdxl.py.

Member: No, I am asking where these are used in the PR.

Contributor Author: They are used in lines 868-873:

lora_config_kwargs["target_modules"] = self.TARGET_MODULES
lora_config_kwargs["modules_to_save"] = self.SAVE_MODULES
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
    adapter_name = "default"

Member: Well, we should be able to infer that without having to directly specify it like this. This is what is done for the others:

def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):

Contributor Author: That function did not produce the expected result here, so I resorted to setting these values directly to serve my purpose.

Member: Well, we need to find a way to tackle this problem. It should not deviate too much from how we go about loading other LoRAs.
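
As context for the thread above, one way to avoid hard-coding TARGET_MODULES and SAVE_MODULES would be to derive both lists from the converted PEFT-format state dict, in the same spirit as get_peft_kwargs infers rank and alpha. The sketch below is illustrative only and not part of this PR; infer_control_lora_modules is a hypothetical helper name, and it assumes PEFT's standard lora_A/lora_B key naming.

# Illustrative sketch (hypothetical helper, not part of the diff): derive the LoRA
# target modules and the fully-trained modules from a PEFT-format state dict.
def infer_control_lora_modules(peft_state_dict):
    target_modules, modules_to_save = set(), set()
    for key in peft_state_dict:
        if ".lora_A." in key or ".lora_B." in key:
            # Full module path, e.g. "mid_block.attentions.0.transformer_blocks.0.attn1.to_q";
            # LoraConfig.target_modules accepts exact module names as well as suffixes.
            target_modules.add(key.split(".lora_A.")[0].split(".lora_B.")[0])
        else:
            # Keys without lora_A/lora_B belong to modules whose full weights are stored,
            # e.g. "controlnet_cond_embedding.conv_in.weight".
            modules_to_save.add(key.rsplit(".", 1)[0])
    return sorted(target_modules), sorted(modules_to_save)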


def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
adapter_name = kwargs.pop("adapter_name", None)
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False

if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

if prefix is not None:
keys = list(state_dict.keys())
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)

# Convert checkpoints in the original Control-LoRA format (marked by the "lora_controlnet" key) to PEFT format
if "lora_controlnet" in state_dict:
del state_dict["lora_controlnet"]
state_dict = convert_control_lora_state_dict_to_peft(state_dict)
Member: If I understand correctly, this is the only change that is required to make the

def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):

method work as expected on the Control LoRA state dict, right?

Contributor Author: Another change is to forcibly set the adapter_name to "default".

Member (@sayakpaul, Feb 8, 2025): That would be a breaking change as we support loading multiple adapters. If this is the only change that is required, I think we can simply port it over to load_lora_adapters() of PeftAdapterMixin. WDYT?

Contributor Author: I believe this is reasonable. This will be addressed after resolving the other issues.
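
A non-breaking variant of the adapter_name handling discussed above could fall back to the naming helper the existing loader relies on instead of always forcing "default". The two lines below are a sketch, not part of the diff; get_adapter_name is already among the imports this file brings in.

if adapter_name is None:
    adapter_name = get_adapter_name(self)  # pick a fresh adapter name rather than hard-coding "default"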


rank = {}
for key, val in state_dict.items():
# Cannot figure out rank from LoRA layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
rank[key] = val.shape[1]

if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

lora_config_kwargs["bias"] = "all"
lora_config_kwargs["target_modules"] = self.TARGET_MODULES
lora_config_kwargs["modules_to_save"] = self.SAVE_MODULES
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = "default"

# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.

# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

# To handle scenarios where we cannot successfully set the state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated with the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
for module in self.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
for active_adapter in active_adapters:
if adapter_name in active_adapter:
module.delete_adapter(adapter_name)

self.peft_config.pop(adapter_name)
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise

warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)

# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
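
For orientation, here is a minimal usage sketch of the new mixin. It is not part of the diff, and the repository id and weight file name are illustrative assumptions; the flow simply exercises what load_lora_adapter above implements (detect the "lora_controlnet" marker, convert the state dict to PEFT format, inject the adapter).

# Minimal usage sketch; the checkpoint repository and file name below are assumptions.
from diffusers import ControlNetModel, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
# Control-LoRA adapts a ControlNet that is initialized from the base UNet weights.
controlnet = ControlNetModel.from_unet(unet)
# ControlNetModel now inherits ControlLoRAMixin, so load_lora_adapter() accepts the
# original single-file Control-LoRA format carrying the "lora_controlnet" key.
controlnet.load_lora_adapter(
    "stabilityai/control-lora",
    weight_name="control-LoRAs-rank128/control-lora-canny-rank128.safetensors",
)

From there the model can be passed to a ControlNet pipeline like any other ControlNetModel.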
3 changes: 2 additions & 1 deletion src/diffusers/models/controlnets/controlnet.py
@@ -19,6 +19,7 @@
from torch.nn import functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, ControlLoRAMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
@@ -106,7 +107,7 @@ def forward(self, conditioning):
return embedding


class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, ControlLoRAMixin):
"""
A ControlNet model.

1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -126,6 +126,7 @@
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
convert_control_lora_state_dict_to_peft,
)
from .typing_utils import _get_detailed_type, _is_valid_type
