Commit 4cc1ae8

simplifying text encoder lora loading.
1 parent 32bc07d commit 4cc1ae8

File tree: 1 file changed (+22 -74 lines)

src/diffusers/loaders/lora_base.py

Lines changed: 22 additions & 74 deletions

@@ -34,7 +34,6 @@
     delete_adapter_layers,
     deprecate,
     get_adapter_name,
-    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
     is_peft_version,
@@ -46,14 +45,13 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.peft_utils import _create_lora_config, _lora_loading_context, _maybe_warn_if_no_keys_found
 from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
     from transformers import PreTrainedModel

-    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer

@@ -352,8 +350,6 @@ def _load_lora_into_text_encoder(
             )
         peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

-    from peft import LoraConfig
-
     # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
     # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
     # their prefixes.
@@ -377,17 +373,9 @@ def _load_lora_into_text_encoder(
         # convert state dict
         state_dict = convert_state_dict_to_peft(state_dict)

-        for name, _ in text_encoder_attn_modules(text_encoder):
-            for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
-
-        for name, _ in text_encoder_mlp_modules(text_encoder):
-            for module in ("fc1", "fc2"):
-                rank_key = f"{name}.{module}.lora_B.weight"
-                if rank_key not in state_dict:
-                    continue
-                rank[rank_key] = state_dict[rank_key].shape[1]
+        for name, _ in text_encoder.named_modules():
+            if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+                rank_key = f"{name}.lora_B.weight"
+                if rank_key in state_dict:
+                    rank[rank_key] = state_dict[rank_key].shape[1]

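The two hard-coded loops over attention and MLP submodules are collapsed into a single pass over `text_encoder.named_modules()`. The rank of each LoRA layer is still read from `lora_B.weight`, whose shape is `(out_features, rank)`, so `shape[1]` is the rank. A minimal, self-contained sketch of that inference on a toy state dict (the module path below is illustrative, not taken from this diff):

```python
# Minimal sketch of rank inference from a PEFT-format LoRA state dict.
# Only the `.lora_B.weight` suffix and the (out_features, rank) shape
# convention matter; the key path is a made-up example.
import torch

state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(8, 768),
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(768, 8),
}

rank = {}
for key, tensor in state_dict.items():
    if key.endswith(".lora_B.weight"):
        rank[key] = tensor.shape[1]  # the inner dimension of lora_B is the rank

print(rank)
# {'text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight': 8}
```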
@@ -394,38 +382,11 @@ def _load_lora_into_text_encoder(
         if network_alphas is not None:
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        if metadata is not None:
-            lora_config_kwargs = metadata
-        else:
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
-        if "use_dora" in lora_config_kwargs:
-            if lora_config_kwargs["use_dora"]:
-                if is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<", "0.9.0"):
-                    lora_config_kwargs.pop("use_dora")
-
-        if "lora_bias" in lora_config_kwargs:
-            if lora_config_kwargs["lora_bias"]:
-                if is_peft_version("<=", "0.13.2"):
-                    raise ValueError(
-                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                    )
-            else:
-                if is_peft_version("<=", "0.13.2"):
-                    lora_config_kwargs.pop("lora_bias")
-
-        try:
-            lora_config = LoraConfig(**lora_config_kwargs)
-        except TypeError as e:
-            raise TypeError("`LoraConfig` class could not be instantiated.") from e
+        # create `LoraConfig`
+        lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)

         # adapter_name
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)

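The inline `LoraConfig` construction, including the `peft` version gating for `use_dora` and `lora_bias`, moves into the private helper `_create_lora_config` imported above from `..utils.peft_utils`; its body is not part of this diff. For reference, the pattern the removed block implemented can be sketched against the public `peft` API roughly as follows (the function name here is hypothetical):

```python
# Rough sketch of the version-gating pattern the removed inline code implemented.
# The actual consolidation lives in diffusers' private `_create_lora_config`
# helper, whose implementation is not shown in this commit.
from packaging import version

import peft
from peft import LoraConfig


def build_text_encoder_lora_config(lora_config_kwargs: dict) -> LoraConfig:
    peft_version = version.parse(peft.__version__)

    # DoRA needs peft >= 0.9.0: raise if it is requested on an older version,
    # otherwise silently drop the unsupported kwarg.
    if "use_dora" in lora_config_kwargs and peft_version < version.parse("0.9.0"):
        if lora_config_kwargs["use_dora"]:
            raise ValueError("You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs.")
        lora_config_kwargs.pop("use_dora")

    # `lora_bias` needs peft >= 0.14.0: same raise-or-drop logic.
    if "lora_bias" in lora_config_kwargs and peft_version <= version.parse("0.13.2"):
        if lora_config_kwargs["lora_bias"]:
            raise ValueError("You need `peft` 0.14.0 at least to use `bias` in LoRAs.")
        lora_config_kwargs.pop("lora_bias")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e
```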
@@ -432,23 +393,17 @@ def _load_lora_into_text_encoder(
-        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
-
-        # inject LoRA layers and load the state dict
-        # in transformers we automatically check whether the adapter name is already in use or not
-        text_encoder.load_adapter(
-            adapter_name=adapter_name,
-            adapter_state_dict=state_dict,
-            peft_config=lora_config,
-            **peft_kwargs,
-        )
-
-        # scale LoRA layers with `lora_scale`
-        scale_lora_layers(text_encoder, weight=lora_scale)
+        # <Unsafe code
+        with _lora_loading_context(_pipeline):
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )

-        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+        # scale LoRA layers with `lora_scale`
+        scale_lora_layers(text_encoder, weight=lora_scale)

-        # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
         # Unsafe code />

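The explicit disable-offload / load / re-enable-offload sequence is replaced by the `_lora_loading_context` context manager, also imported from `..utils.peft_utils`. Its implementation is not shown in this commit; based only on the removed lines and on `_func_optionally_disable_offloading` (defined further down in this file), such a context manager could plausibly look like the sketch below. This is an assumption about a private helper, not the actual diffusers code:

```python
# Speculative sketch only: a context manager reproducing the removed
# disable-offload / restore-offload bracketing around adapter loading.
# The real `_lora_loading_context` lives in diffusers.utils.peft_utils and may
# differ; `_func_optionally_disable_offloading` is the module-level helper
# defined later in lora_base.py.
from contextlib import contextmanager


@contextmanager
def lora_loading_context(_pipeline):
    # Detect and temporarily remove any accelerate CPU-offload hooks so the
    # adapter weights can be injected into the text encoder.
    is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
    try:
        yield
    finally:
        # Restore whichever offloading mode was active before loading.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
```

One behavioural detail visible in the diff itself: the context now only wraps `text_encoder.load_adapter(...)`, while `scale_lora_layers` and the final `.to(...)` call run after the context exits, whereas the old code re-enabled offloading only at the very end of the block.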
@@ -455,11 +410,4 @@ def _load_lora_into_text_encoder(
-    if prefix is not None and not state_dict:
-        logger.warning(
-            f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
-            "This is safe to ignore if LoRA state dict didn't originally have any "
-            f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
-            "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
-            "https://github.com/huggingface/diffusers/issues/new"
-        )
+    _maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name=text_encoder.__class__.__name__)


 def _func_optionally_disable_offloading(_pipeline):
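The trailing "no LoRA keys found" warning is likewise folded into `_maybe_warn_if_no_keys_found`, with the call-site signature shown above. Reconstructed from the removed lines, its behaviour amounts to roughly the following (the real helper is private to `diffusers.utils.peft_utils` and may differ in detail):

```python
# Sketch reconstructed from the removed inline warning; not the actual
# diffusers implementation.
import logging

logger = logging.getLogger(__name__)


def maybe_warn_if_no_keys_found(state_dict, prefix, model_class_name):
    if prefix is not None and not state_dict:
        logger.warning(
            f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
            "This is safe to ignore if LoRA state dict didn't originally have any "
            f"{model_class_name} related params. You can also try specifying `prefix=None` "
            "to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
            "https://github.com/huggingface/diffusers/issues/new"
        )
```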

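None of this changes the public API: the refactored path is still reached through the usual LoRA-loading entry points. An illustrative call (the repo IDs below are placeholders, not part of this commit):

```python
# Illustrative only: loading a LoRA checkpoint that contains text-encoder
# weights goes through `_load_lora_into_text_encoder`. Repo IDs are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "path/or/repo-id-of-base-model", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/or/repo-id-of-lora", adapter_name="my_lora")
pipe.set_adapters(["my_lora"], adapter_weights=[0.8])
```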