Merged
src/diffusers/loaders/lora_base.py (10 additions, 57 deletions)
@@ -34,7 +34,6 @@
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
@@ -46,14 +45,13 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config, _maybe_warn_if_no_keys_found
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


if is_transformers_available():
from transformers import PreTrainedModel

from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

from peft import LoraConfig

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
@@ -377,60 +373,25 @@
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)

for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]

for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
rank_key = f"{name}.lora_B.weight"
if rank_key in state_dict:
rank[rank_key] = state_dict[rank_key].shape[1]

if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# create `LoraConfig`
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -442,7 +403,6 @@ def _load_lora_into_text_encoder(

# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

# Offload back.
@@ -452,14 +412,7 @@ def _load_lora_into_text_encoder(
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
_maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=text_encoder.__class__.__name__)


def _func_optionally_disable_offloading(_pipeline):
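
The version-gating and `LoraConfig` instantiation deleted in the hunks above is consolidated behind the new `_create_lora_config` helper imported from src/diffusers/utils/peft_utils.py. Below is a minimal sketch of what that helper plausibly looks like, reconstructed only from the removed inline logic and the call sites in this diff; the parameter names, the `is_unet` default, and the top-level import paths are assumptions, not the shipped implementation.

    from diffusers.utils import get_peft_kwargs, is_peft_version  # assumed public import path
    from peft import LoraConfig


    def _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=True):
        # Prefer explicitly supplied adapter metadata over kwargs inferred from the state dict.
        if metadata is not None:
            lora_config_kwargs = metadata
        else:
            lora_config_kwargs = get_peft_kwargs(
                rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
            )
            # peft.py previously also called `_maybe_raise_error_for_ambiguity(lora_config_kwargs)`
            # at this point; presumably that check moved alongside this helper (assumption).

        # DoRA needs peft >= 0.9.0: error out if the flag is set, otherwise drop it on old versions.
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
                if is_peft_version("<", "0.9.0"):
                    raise ValueError(
                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                    )
            else:
                if is_peft_version("<", "0.9.0"):
                    lora_config_kwargs.pop("use_dora")

        # `lora_bias` needs peft >= 0.14.0: same pattern.
        if "lora_bias" in lora_config_kwargs:
            if lora_config_kwargs["lora_bias"]:
                if is_peft_version("<=", "0.13.2"):
                    raise ValueError(
                        "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
                    )
            else:
                if is_peft_version("<=", "0.13.2"):
                    lora_config_kwargs.pop("lora_bias")

        try:
            return LoraConfig(**lora_config_kwargs)
        except TypeError as e:
            raise TypeError("`LoraConfig` class could not be instantiated.") from e

With a helper like this, both loaders reduce to a single call such as `_create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)`, as the added lines in this file show.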
src/diffusers/loaders/peft.py (11 additions, 89 deletions)
@@ -29,13 +29,17 @@
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import (
_create_lora_config,
_maybe_warn_for_unhandled_keys,
_maybe_warn_if_no_keys_found,
)
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales

@@ -64,26 +68,6 @@
}


def _maybe_raise_error_for_ambiguity(config):
rank_pattern = config["rank_pattern"].copy()
target_modules = config["target_modules"]

for key in list(rank_pattern.keys()):
# try to detect ambiguity
# `target_modules` can also be a str, in which case this loop would loop
# over the chars of the str. The technically correct way to match LoRA keys
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
# But this cuts it for now.
exact_matches = [mod for mod in target_modules if mod == key]
substring_matches = [mod for mod in target_modules if key in mod and mod != key]

if exact_matches and substring_matches:
if is_peft_version("<", "0.14.1"):
raise ValueError(
"There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
)


class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -191,7 +175,7 @@ def load_lora_adapter(
LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
initialize `LoraConfig`.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

cache_dir = kwargs.pop("cache_dir", None)
@@ -216,7 +200,6 @@
)

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
@@ -275,38 +258,8 @@ def load_lora_adapter(
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

# adapter_name
if adapter_name is None:
@@ -317,9 +270,8 @@
# Now we remove any existing hooks to `_pipeline`.

# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -403,30 +355,7 @@ def map_state_dict_for_hotswap(sd):
logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise

warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)

# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)

if warn_msg:
logger.warning(warn_msg)
_maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)

# Offload back.
if is_model_cpu_offload:
@@ -435,14 +364,7 @@ def map_state_dict_for_hotswap(sd):
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />

if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
_maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=self.__class__.__name__)

def save_lora_adapter(
self,
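
Similarly, the inline warning blocks removed from both loaders now sit behind `_maybe_warn_for_unhandled_keys` and `_maybe_warn_if_no_keys_found`, also imported from src/diffusers/utils/peft_utils.py. The sketch below mirrors the deleted code above almost verbatim; the module-level logger setup and the public import path for `logging` are assumptions rather than the actual module contents.

    from diffusers.utils import logging  # assumed import path for the diffusers logger

    logger = logging.get_logger(__name__)


    def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
        warn_msg = ""
        if incompatible_keys is not None:
            # Check only for unexpected keys.
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
                if lora_unexpected_keys:
                    warn_msg = (
                        f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                        f" {', '.join(lora_unexpected_keys)}. "
                    )

            # Filter missing keys specific to the current adapter.
            missing_keys = getattr(incompatible_keys, "missing_keys", None)
            if missing_keys:
                lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
                if lora_missing_keys:
                    warn_msg += (
                        f"Loading adapter weights from state_dict led to missing keys in the model:"
                        f" {', '.join(lora_missing_keys)}."
                    )

        if warn_msg:
            logger.warning(warn_msg)


    def _maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name):
        # Warn when a prefix was requested but no matching LoRA keys survived filtering.
        if prefix is not None and not state_dict:
            logger.warning(
                f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
                "This is safe to ignore if LoRA state dict didn't originally have any "
                f"{model_class_name} related params. You can also try specifying `prefix=None` "
                "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
                "https://github.com/huggingface/diffusers/issues/new"
            )

Factoring these out lets `_load_lora_into_text_encoder` and `PeftAdapterMixin.load_lora_adapter` share one warning path, which is what the one-line replacements in both files call into.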