13 changes: 10 additions & 3 deletions src/diffusers/loaders/peft.py
@@ -243,13 +243,20 @@ def load_lora_adapter(
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)

# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)

# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
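For context, the call-site change above threads the model's own state dict and the adapter name into `_create_lora_config`, so module names that exist in the model but have no LoRA weights can be excluded. A minimal usage sketch of the scenario this targets (the repo ID and LoRA path below are placeholders, not taken from this PR):

```python
# Hedged sketch: loading a Wan-style LoRA into a Wan VACE pipeline whose
# transformer contains extra (VACE-specific) modules the LoRA does not cover.
import torch
from diffusers import WanVACEPipeline

pipe = WanVACEPipeline.from_pretrained(
    "Wan-AI/Wan2.1-VACE-1.3B-diffusers",  # placeholder repo ID
    torch_dtype=torch.bfloat16,
)
# With this change, modules present only in the VACE transformer end up in
# `exclude_modules`, so PEFT does not try to inject LoRA layers into them.
pipe.load_lora_weights("path/to/FusionX_lora.safetensors", adapter_name="fusionx")
```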
41 changes: 33 additions & 8 deletions src/diffusers/utils/peft_utils.py
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
@@ -180,18 +182,23 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()

# layer names without the Diffusers-specific LoRA suffix
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

# exclude modules that exist in the model but have no LoRA weights, e.g. when loading a FusionX LoRA into Wan VACE
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
if not exclude_modules:
exclude_modules = None

lora_config_kwargs = {
"r": r,
"lora_alpha": lora_alpha,
"rank_pattern": rank_pattern,
"alpha_pattern": alpha_pattern,
"target_modules": target_modules,
"exclude_modules": exclude_modules,
"use_dora": use_dora,
"lora_bias": lora_bias,
}
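These kwargs are later forwarded to `peft.LoraConfig`, so the new `exclude_modules` entry only takes effect with a PEFT version that supports that field. A toy illustration of a resulting config (module names and values are invented, for shape only):

```python
from peft import LoraConfig

# Invented values, only to show the shape of the kwargs assembled above.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    rank_pattern={},       # per-module rank overrides
    alpha_pattern={},      # per-module alpha overrides
    target_modules=["blocks.0.attn.to_q", "blocks.0.attn.to_k"],
    exclude_modules=["vace_blocks.0.proj_out"],  # in the model, absent from the LoRA
    use_dora=False,
    lora_bias=False,
)
```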
@@ -294,19 +301,20 @@ def check_peft_version(min_version: str) -> None:


def _create_lora_config(
state_dict,
network_alphas,
metadata,
rank_pattern_dict,
is_unet: bool = True,
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
from peft import LoraConfig

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
rank_pattern_dict,
network_alpha_dict=network_alphas,
peft_state_dict=state_dict,
is_unet=is_unet,
model_state_dict=model_state_dict,
adapter_name=adapter_name,
)

_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +379,20 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):

if warn_msg:
logger.warning(warn_msg)


def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
all_modules = set()
string_to_replace = f"{adapter_name}." if adapter_name else ""

for name in model_state_dict.keys():
if string_to_replace:
Contributor review comment (on the `if` above): maybe the if-statement is not needed here because `string_to_replace` will be an empty string when `adapter_name` is None, but no problem keeping it as a micro-optimization.
name = name.replace(string_to_replace, "")
if "." in name:
module_name = name.rsplit(".", 1)[0]
all_modules.add(module_name)

target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
exclude_modules = list(all_modules - target_modules_set)

return exclude_modules
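
To make the set arithmetic concrete, here is a toy walk-through of `_derive_exclude_modules` (the keys are invented, and the function above is assumed to be in scope):

```python
# Model has one module the LoRA does not target; it should be excluded.
model_state_dict = {
    "blocks.0.attn.to_q.weight": None,
    "vace_blocks.0.proj_out.weight": None,  # exists only in the model
}
peft_state_dict = {
    "blocks.0.attn.to_q.lora_A.weight": None,
    "blocks.0.attn.to_q.lora_B.weight": None,
}

print(_derive_exclude_modules(model_state_dict, peft_state_dict))
# -> ['vace_blocks.0.proj_out']
```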