Merged
63 changes: 12 additions & 51 deletions src/diffusers/loaders/lora_base.py
@@ -34,7 +34,6 @@
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
@@ -46,14 +45,13 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


if is_transformers_available():
from transformers import PreTrainedModel

from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

from peft import LoraConfig

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
@@ -377,60 +373,25 @@
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)

for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]

for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
rank_key = f"{name}.lora_B.weight"
if rank_key in state_dict:
rank[rank_key] = state_dict[rank_key].shape[1]

if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# create `LoraConfig`
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -442,7 +403,6 @@ def _load_lora_into_text_encoder(

# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

# Offload back.
@@ -453,10 +413,11 @@
# Unsafe code />

if prefix is not None and not state_dict:
model_class_name = text_encoder.__class__.__name__
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
f"{model_class_name} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
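
The rank-detection change above replaces the two explicit walks over attention and MLP modules with a single pass over `text_encoder.named_modules()`, keyed on the known projection/MLP suffixes. A minimal, self-contained sketch of the same idea, using a toy state dict with made-up key names rather than a real text encoder:

```python
import torch

# Toy LoRA state dict; the key names are illustrative only, not from the PR.
state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(768, 4),
    "text_model.encoder.layers.0.mlp.fc1.lora_B.weight": torch.zeros(3072, 8),
}

# Same idea as the refactored loop: any module whose name ends with one of the
# known suffixes contributes its LoRA rank (dimension 1 of the lora_B matrix).
suffixes = (".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")
rank = {}
for key, tensor in state_dict.items():
    if key.endswith(".lora_B.weight") and key.removesuffix(".lora_B.weight").endswith(suffixes):
        rank[key] = tensor.shape[1]

print(rank)  # ranks 4 and 8 for the two toy entries
```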
92 changes: 9 additions & 83 deletions src/diffusers/loaders/peft.py
@@ -29,13 +29,13 @@
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales

@@ -64,26 +64,6 @@
}


def _maybe_raise_error_for_ambiguity(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]

for key in list(rank_pattern.keys()):
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]

if exact_matches and substring_matches:
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)


class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
Expand Down Expand Up @@ -191,7 +171,7 @@ def load_lora_adapter(
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
initialize `LoraConfig`.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +196,6 @@
)

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
@@ -275,38 +254,8 @@ def load_lora_adapter(
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

# adapter_name
if adapter_name is None:
@@ -317,9 +266,8 @@
# Now we remove any existing hooks to `_pipeline`.

# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -403,30 +351,7 @@ def map_state_dict_for_hotswap(sd):
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise

warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)
_maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)

# Offload back.
if is_model_cpu_offload:
@@ -436,10 +361,11 @@ def map_state_dict_for_hotswap(sd):
# Unsafe code />

if prefix is not None and not state_dict:
model_class_name = self.__class__.__name__
logger.warning(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
f"{model_class_name} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
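
Taken together, the peft.py changes shrink the LoRA-loading path to roughly the flow below. This is a simplified outline, not a drop-in replacement for the real method: it assumes the state dict, `network_alphas`, `metadata`, and the `rank` dict have already been prepared as in the diff, and it omits offloading, hotswap, and low-CPU-memory handling.

```python
from peft import inject_adapter_in_model, set_peft_model_state_dict

from diffusers.utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys


def load_lora_adapter_sketch(model, state_dict, network_alphas, metadata, rank, adapter_name):
    # All LoraConfig construction (metadata vs. inferred kwargs, the ambiguity
    # check, and the DoRA / lora_bias version gating) now lives in one helper.
    lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

    # Inject the adapter layers, then load the LoRA weights into them.
    inject_adapter_in_model(lora_config, model, adapter_name=adapter_name)
    incompatible_keys = set_peft_model_state_dict(model, state_dict, adapter_name=adapter_name)

    # Unexpected / missing key warnings are also centralized in peft_utils now.
    _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
```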
85 changes: 84 additions & 1 deletion src/diffusers/utils/peft_utils.py
@@ -21,9 +21,12 @@

from packaging import version

from .import_utils import is_peft_available, is_torch_available
from . import logging
from .import_utils import is_peft_available, is_peft_version, is_torch_available


logger = logging.get_logger(__name__)

if is_torch_available():
import torch

@@ -288,3 +291,83 @@ def check_peft_version(min_version: str) -> None:
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)


def _create_lora_config(
state_dict,
network_alphas,
metadata,
rank_pattern_dict,
is_unet: bool = True,
):
from peft import LoraConfig

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
)

_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)

# Version checks for DoRA and lora_bias
if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")

if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")

try:
return LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e


def _maybe_raise_error_for_ambiguous_keys(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]

for key in list(rank_pattern.keys()):
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]

if exact_matches and substring_matches:
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)
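
A quick way to exercise the new warning helper, assuming a diffusers checkout with this PR applied; the named tuple below is only a stand-in for the load result returned by peft's `set_peft_model_state_dict`:

```python
from collections import namedtuple

from diffusers.utils.peft_utils import _maybe_warn_for_unhandled_keys

# Stand-in for peft's load result, which exposes `missing_keys` and `unexpected_keys`.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

keys = IncompatibleKeys(
    missing_keys=["layers.0.attn.to_q.lora_A.my_adapter.weight"],
    unexpected_keys=["layers.1.attn.to_k.lora_B.my_adapter.weight"],
)

# Logs a single warning listing both the unexpected and the missing LoRA keys
# that mention "my_adapter"; with empty lists, nothing is logged.
_maybe_warn_for_unhandled_keys(keys, "my_adapter")
```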