Commit ca8080d

factor out text encoder loading.
1 parent e2deb82 · commit ca8080d

2 files changed: 174 additions, 135 deletions

src/diffusers/loaders/lora_base.py

Lines changed: 156 additions & 21 deletions
@@ -28,13 +28,20 @@
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     delete_adapter_layers,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_peft_available,
+    is_peft_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
@@ -43,6 +50,8 @@
 if is_transformers_available():
     from transformers import PreTrainedModel
 
+    from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
+
 if is_peft_available():
     from peft.tuners.tuners_utils import BaseTunerLayer
 
@@ -297,6 +306,152 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _load_lora_into_text_encoder(
+    state_dict,
+    network_alphas,
+    text_encoder,
+    prefix=None,
+    lora_scale=1.0,
+    text_encoder_name="text_encoder",
+    adapter_name=None,
+    _pipeline=None,
+    low_cpu_mem_usage=False,
+):
+    if not USE_PEFT_BACKEND:
+        raise ValueError("PEFT backend is required for this method.")
+
+    peft_kwargs = {}
+    if low_cpu_mem_usage:
+        if not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+        if not is_transformers_version(">", "4.45.2"):
+            # Note from sayakpaul: It's not in `transformers` stable yet.
+            # https://github.com/huggingface/transformers/pull/33725/
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+            )
+        peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+    from peft import LoraConfig
+
+    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+    # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
+    # their prefixes.
+    keys = list(state_dict.keys())
+    prefix = text_encoder_name if prefix is None else prefix
+
+    # Safe prefix to check with.
+    if any(text_encoder_name in key for key in keys):
+        # Load the layers corresponding to text encoder and make necessary adjustments.
+        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+        text_encoder_lora_state_dict = {
+            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+        }
+
+        if len(text_encoder_lora_state_dict) > 0:
+            logger.info(f"Loading {prefix}.")
+            rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+            # convert state dict
+            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+            for name, _ in text_encoder_attn_modules(text_encoder):
+                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            for name, _ in text_encoder_mlp_modules(text_encoder):
+                for module in ("fc1", "fc2"):
+                    rank_key = f"{name}.{module}.lora_B.weight"
+                    if rank_key not in text_encoder_lora_state_dict:
+                        continue
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+            if network_alphas is not None:
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
+
+            is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+
+            # inject LoRA layers and load the state dict
+            # in transformers we automatically check whether the adapter name is already in use or not
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+                **peft_kwargs,
+            )
+
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
+
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+
+def _func_optionally_disable_offloading(_pipeline):
+    is_model_cpu_offload = False
+    is_sequential_cpu_offload = False
+
+    if _pipeline is not None and _pipeline.hf_device_map is None:
+        for _, component in _pipeline.components.items():
+            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                if not is_model_cpu_offload:
+                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                if not is_sequential_cpu_offload:
+                    is_sequential_cpu_offload = (
+                        isinstance(component._hf_hook, AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
+
+                logger.info(
+                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                )
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+    return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+
 class LoraBaseMixin:
     """Utility class for handling LoRAs."""
 
@@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline):
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-
-        if _pipeline is not None and _pipeline.hf_device_map is None:
-            for _, component in _pipeline.components.items():
-                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                    if not is_model_cpu_offload:
-                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                    if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload = (
-                            isinstance(component._hf_hook, AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
-        return (is_model_cpu_offload, is_sequential_cpu_offload)
+        return _func_optionally_disable_offloading(_pipeline=_pipeline)
 
     @classmethod
     def _fetch_state_dict(cls, *args, **kwargs):
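
To make the new control flow concrete, here is a minimal, hedged sketch of calling the factored-out helper directly on a CLIP text encoder. The checkpoint file name, adapter name, and the choice of `CLIPTextModel` are illustrative assumptions, not part of this commit; the supported entry point remains `load_lora_weights()` on a pipeline, and the PEFT backend must be installed.

# Illustrative sketch only: exercise the new module-level helper directly.
# File name, adapter name, and text encoder checkpoint are placeholders.
from safetensors.torch import load_file
from transformers import CLIPTextModel

from diffusers.loaders.lora_base import _load_lora_into_text_encoder

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Keys are expected to carry the "text_encoder." prefix (the serialization format
# referenced in the comments above); keys without that prefix are ignored here.
state_dict = load_file("pytorch_lora_weights.safetensors")  # placeholder file

_load_lora_into_text_encoder(
    state_dict=state_dict,
    network_alphas=None,
    text_encoder=text_encoder,
    prefix="text_encoder",
    lora_scale=1.0,
    adapter_name="my_adapter",  # placeholder; defaults to an auto-generated name
    _pipeline=None,             # offload probing via _func_optionally_disable_offloading is a no-op without a pipeline
    low_cpu_mem_usage=False,    # True requires peft>=0.13.1 and transformers>4.45.2
)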

src/diffusers/loaders/lora_pipeline.py

Lines changed: 18 additions & 114 deletions
@@ -33,7 +33,13 @@
     logging,
     scale_lora_layers,
 )
-from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
+from .lora_base import (  # noqa
+    LORA_WEIGHT_NAME,
+    LORA_WEIGHT_NAME_SAFE,
+    LoraBaseMixin,
+    _fetch_state_dict,
+    _load_lora_into_text_encoder,
+)
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
@@ -349,119 +355,17 @@ def load_lora_into_text_encoder(
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
         """
-        if not USE_PEFT_BACKEND:
-            raise ValueError("PEFT backend is required for this method.")
-
-        peft_kwargs = {}
-        if low_cpu_mem_usage:
-            if not is_peft_version(">=", "0.13.1"):
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
-                )
-            if not is_transformers_version(">", "4.45.2"):
-                # Note from sayakpaul: It's not in `transformers` stable yet.
-                # https://github.com/huggingface/transformers/pull/33725/
-                raise ValueError(
-                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
-                )
-            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-
-        from peft import LoraConfig
-
-        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
-        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
-        # their prefixes.
-        keys = list(state_dict.keys())
-        prefix = cls.text_encoder_name if prefix is None else prefix
-
-        # Safe prefix to check with.
-        if any(cls.text_encoder_name in key for key in keys):
-            # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
-            text_encoder_lora_state_dict = {
-                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
-            }
-
-            if len(text_encoder_lora_state_dict) > 0:
-                logger.info(f"Loading {prefix}.")
-                rank = {}
-                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
-
-                # convert state dict
-                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
-
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    for module in ("fc1", "fc2"):
-                        rank_key = f"{name}.{module}.lora_B.weight"
-                        if rank_key not in text_encoder_lora_state_dict:
-                            continue
-                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
-
-                if network_alphas is not None:
-                    alpha_keys = [
-                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
-                    ]
-                    network_alphas = {
-                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                    }
-
-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
-
-                if "use_dora" in lora_config_kwargs:
-                    if lora_config_kwargs["use_dora"]:
-                        if is_peft_version("<", "0.9.0"):
-                            raise ValueError(
-                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<", "0.9.0"):
-                            lora_config_kwargs.pop("use_dora")
-
-                if "lora_bias" in lora_config_kwargs:
-                    if lora_config_kwargs["lora_bias"]:
-                        if is_peft_version("<=", "0.13.2"):
-                            raise ValueError(
-                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
-                            )
-                    else:
-                        if is_peft_version("<=", "0.13.2"):
-                            lora_config_kwargs.pop("lora_bias")
-
-                lora_config = LoraConfig(**lora_config_kwargs)
-
-                # adapter_name
-                if adapter_name is None:
-                    adapter_name = get_adapter_name(text_encoder)
-
-                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-                # inject LoRA layers and load the state dict
-                # in transformers we automatically check whether the adapter name is already in use or not
-                text_encoder.load_adapter(
-                    adapter_name=adapter_name,
-                    adapter_state_dict=text_encoder_lora_state_dict,
-                    peft_config=lora_config,
-                    **peft_kwargs,
-                )
-
-                # scale LoRA layers with `lora_scale`
-                scale_lora_layers(text_encoder, weight=lora_scale)
-
-                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
-
-                # Offload back.
-                if is_model_cpu_offload:
-                    _pipeline.enable_model_cpu_offload()
-                elif is_sequential_cpu_offload:
-                    _pipeline.enable_sequential_cpu_offload()
-                # Unsafe code />
+        _load_lora_into_text_encoder(
+            state_dict=state_dict,
+            network_alphas=network_alphas,
+            lora_scale=lora_scale,
+            text_encoder=text_encoder,
+            prefix=prefix,
+            text_encoder_name=cls.text_encoder_name,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
 
     @classmethod
     def save_lora_weights(
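
At the pipeline level the behavior is unchanged by this refactor; only the internal call path moves from the mixin method into `lora_base._load_lora_into_text_encoder`. A hedged end-to-end sketch follows, where the model id and LoRA repository are placeholders rather than part of this commit.

# Sketch of the unchanged public API. The mixin's `load_lora_into_text_encoder`
# now simply forwards to `_load_lora_into_text_encoder` in lora_base.py.
# Model id and LoRA repository below are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# If the checkpoint contains `text_encoder.` keys, they are routed through the
# helper factored out in this commit; `unet.` keys follow the existing path.
pipe.load_lora_weights("some-user/some-lora", adapter_name="style")  # placeholder repo id

image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]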
