Merged
96 changes: 22 additions & 74 deletions src/diffusers/loaders/lora_base.py
@@ -34,7 +34,6 @@
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
@@ -46,14 +45,13 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config, _lora_loading_context, _maybe_warn_if_no_keys_found
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


if is_transformers_available():
from transformers import PreTrainedModel

from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

from peft import LoraConfig

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
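For context, the newer layout looks roughly like this; the key names below are hypothetical and only meant to illustrate the prefixing, not output from an actual checkpoint:

```python
# Illustrative only: with the newer serialization format, every key carries a
# component prefix such as "unet." or "text_encoder.". Key names are hypothetical.
state_dict = {
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": ...,
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": ...,
    "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": ...,
}

# Only the keys under `prefix` (e.g. "text_encoder") are kept for this model,
# and the prefix is stripped before the state dict is handed to peft.
text_encoder_keys = {
    k.removeprefix("text_encoder."): v
    for k, v in state_dict.items()
    if k.startswith("text_encoder.")
}
```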
@@ -377,89 +373,41 @@
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)

for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]

for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder.named_modules():
if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
rank_key = f"{name}.lora_B.weight"
if rank_key in state_dict:
rank[rank_key] = state_dict[rank_key].shape[1]
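As a side note, here is a minimal sketch of how the rank is recovered from a `lora_B` weight; the module path is made up, and the shape follows peft's convention that `lora_B.weight` is `(out_features, r)`:

```python
# Minimal sketch: the rank of a LoRA layer is the second dimension of its
# lora_B weight, since peft stores lora_B as (out_features, r).
# The module path below is hypothetical.
import torch

state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(768, 8),
}

rank = {}
for name in ["text_model.encoder.layers.0.self_attn.q_proj"]:
    rank_key = f"{name}.lora_B.weight"
    if rank_key in state_dict:
        rank[rank_key] = state_dict[rank_key].shape[1]

print(rank)  # {'text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight': 8}
```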

if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
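A small illustration of the alpha filtering above, with hypothetical keys and values; the `k.split(".")[0] == prefix` check is what keeps, say, `text_encoder_2.*` keys from leaking into `text_encoder`:

```python
# Illustrative only: keep the alphas that belong to this component and strip
# the prefix. Keys and values are hypothetical.
prefix = "text_encoder"
network_alphas = {
    "text_encoder.layers.0.q_proj.alpha": 8.0,
    "text_encoder_2.layers.0.q_proj.alpha": 16.0,  # startswith(prefix) is True, but it belongs to another component
}

alpha_keys = [k for k in network_alphas if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

print(network_alphas)  # {'layers.0.q_proj.alpha': 8.0}
```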

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# create `LoraConfig`
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
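`_create_lora_config` is imported from `..utils.peft_utils` and its body is not part of this diff. Judging by the inline logic it replaces, a rough sketch could look like the following; this is an assumption rather than the actual implementation, with `get_peft_kwargs` and `is_peft_version` being the existing `diffusers.utils` helpers:

```python
# Hypothetical sketch of what a helper like `_create_lora_config` might
# consolidate, inferred from the inline logic it replaces above. The actual
# implementation lives in diffusers.utils.peft_utils and may differ.
from peft import LoraConfig

from diffusers.utils import get_peft_kwargs, is_peft_version


def _create_lora_config_sketch(state_dict, network_alphas, metadata, rank, is_unet=False):
    # Prefer the metadata shipped with the checkpoint; otherwise derive the
    # config kwargs (r, lora_alpha, target_modules, ...) from the state dict.
    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=is_unet)

    # DoRA-enabled LoRAs need peft>=0.9.0; on older peft, drop the flag if it is unset.
    if lora_config_kwargs.get("use_dora"):
        if is_peft_version("<", "0.9.0"):
            raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
    elif "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
        lora_config_kwargs.pop("use_dora")

    # `lora_bias` needs peft>=0.14.0; same drop-or-raise logic.
    if lora_config_kwargs.get("lora_bias"):
        if is_peft_version("<=", "0.13.2"):
            raise ValueError("You need `peft` 0.14.0 at least to use `bias` in LoRAs.")
    elif "lora_bias" in lora_config_kwargs and is_peft_version("<=", "0.13.2"):
        lora_config_kwargs.pop("lora_bias")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
```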

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)

# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
# <Unsafe code
with _lora_loading_context(_pipeline):
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)

text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Unsafe code />
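`_lora_loading_context` also comes from `..utils.peft_utils` and is not shown in this diff. Based on the offload handling it replaces, a plausible reconstruction is sketched below; the real helper may cover more cases. `_func_optionally_disable_offloading` is the module-level function defined further down in this same file:

```python
# Hypothetical reconstruction of a context manager like `_lora_loading_context`,
# based on the offload disable/re-enable sequence it replaces above. The real
# helper in diffusers.utils.peft_utils may differ.
from contextlib import contextmanager

from diffusers.loaders.lora_base import _func_optionally_disable_offloading


@contextmanager
def _lora_loading_context_sketch(_pipeline):
    # Detach accelerate's CPU-offload hooks so the LoRA weights can be
    # injected into the underlying modules.
    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
    try:
        yield
    finally:
        # Restore whichever offloading scheme was active before loading.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
```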

if prefix is not None and not state_dict:
logger.warning(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
_maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=text_encoder.__class__.__name__)
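`_maybe_warn_if_no_keys_found` is likewise not shown here; reconstructed from the inline warning it replaces, it presumably does something along these lines (the exact wording may differ):

```python
# Hypothetical sketch of a helper like `_maybe_warn_if_no_keys_found`,
# reconstructed from the inline warning it replaces; the real helper lives in
# diffusers.utils.peft_utils and its exact message may differ.
from diffusers.utils import logging

logger = logging.get_logger(__name__)


def _maybe_warn_if_no_keys_found_sketch(state_dict, prefix, model_class_name):
    if prefix is not None and not state_dict:
        logger.warning(
            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
            "This is safe to ignore if the LoRA state dict didn't originally have any "
            f"{model_class_name} related params. You can also try specifying `prefix=None` "
            "to resolve the warning."
        )
```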


def _func_optionally_disable_offloading(_pipeline):