Commit 5852547

feat: lora support for Lumina2.

1 parent: b75b204

6 files changed: +481 and −4 lines
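Taken together, the changes add a `Lumina2LoraLoaderMixin`, export it from `diffusers.loaders`, register `Lumina2Transformer2DModel` with the PEFT adapter machinery, and teach the transformer's forward pass to scale LoRA layers at runtime. A minimal sketch of the workflow this enables — assuming the Lumina2 pipeline mixes the loader in (the pipeline-side change is presumably in one of the two changed files not shown in this view) and using a purely illustrative LoRA file name:

    import torch
    from diffusers import Lumina2Text2ImgPipeline

    # "Alpha-VLLM/Lumina-Image-2.0" is the reference Lumina2 checkpoint on the Hub;
    # the LoRA file name below is a placeholder, not a real artifact.
    pipe = Lumina2Text2ImgPipeline.from_pretrained(
        "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipe.load_lora_weights("my_lumina2_lora.safetensors", adapter_name="style")

    image = pipe(prompt="a portrait in the trained style").images[0]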

src/diffusers/loaders/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -73,6 +73,7 @@ def text_encoder_attn_modules(text_encoder):
         "Mochi1LoraLoaderMixin",
         "HunyuanVideoLoraLoaderMixin",
         "SanaLoraLoaderMixin",
+        "Lumina2LoraLoaderMixin",
     ]
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = [

@@ -105,6 +106,7 @@ def text_encoder_attn_modules(text_encoder):
         HunyuanVideoLoraLoaderMixin,
         LoraLoaderMixin,
         LTXVideoLoraLoaderMixin,
+        Lumina2LoraLoaderMixin,
         Mochi1LoraLoaderMixin,
         SanaLoraLoaderMixin,
         SD3LoraLoaderMixin,
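Both hunks are needed: the first registers the class in the lazy `_import_structure` map, the second adds it to the direct import list used when type checking. After this change the mixin resolves from the subpackage; a trivial check:

    # Follows directly from the two added lines above.
    from diffusers.loaders import Lumina2LoraLoaderMixin

    assert Lumina2LoraLoaderMixin._lora_loadable_modules == ["transformer"]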

src/diffusers/loaders/lora_pipeline.py
Lines changed: 309 additions & 0 deletions

@@ -3805,6 +3805,315 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components)


+class Lumina2LoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading original format HunyuanVideo LoRA checkpoints.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`HunyuanVideoTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components)
+
+
 class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
     def __init__(self, *args, **kwargs):
         deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
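Every method in the new class is assembled from existing mixins via `# Copied from` markers (CogVideoX, SD3, Sana), so its behavior matches those counterparts; only the transformer is LoRA-loadable. As a hedged sketch of the save path, with stand-in tensors in place of what training would actually produce:

    import torch
    from diffusers.loaders import Lumina2LoraLoaderMixin

    # Stand-in LoRA tensors with illustrative keys and shapes; in practice these
    # come from training, e.g. peft's get_peft_model_state_dict(transformer).
    # Every key must contain "lora", or load_lora_weights() rejects the checkpoint.
    transformer_lora_layers = {
        "layers.0.attn.to_q.lora_A.weight": torch.zeros(8, 64),
        "layers.0.attn.to_q.lora_B.weight": torch.zeros(64, 8),
    }

    # save_lora_weights is a classmethod, so no pipeline instance is needed.
    Lumina2LoraLoaderMixin.save_lora_weights(
        save_directory="./lumina2-lora",  # created if it doesn't exist
        transformer_lora_layers=transformer_lora_layers,
        weight_name="pytorch_lora_weights.safetensors",
    )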

src/diffusers/loaders/peft.py
Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@
     "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
     "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
    "SanaTransformer2DModel": lambda model_cls, weights: weights,
+    "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
 }
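The identity lambda says Lumina2 needs no model-specific expansion of adapter scales when weights are distributed across the model, unlike some UNet-based models that expand scales per block. With the entry in place, the standard multi-adapter helpers work; a sketch with hypothetical adapter files, reusing `pipe` from above:

    # Blend two adapters; the identity mapping passes the weights through as-is.
    pipe.load_lora_weights("lora_a.safetensors", adapter_name="a")
    pipe.load_lora_weights("lora_b.safetensors", adapter_name="b")
    pipe.set_adapters(["a", "b"], adapter_weights=[1.0, 0.5])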
src/diffusers/models/transformers/transformer_lumina2.py
Lines changed: 22 additions & 2 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -22,7 +22,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import logging
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed

@@ -461,8 +461,24 @@ def forward(
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         # 1. Condition, positional & patch embedding
         batch_size, _, height, width = hidden_states.shape

@@ -523,6 +539,10 @@ def forward(
             )
         output = torch.stack(output, dim=0)

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
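With this plumbing, LoRA strength becomes a per-call knob rather than something fixed at fuse time: `forward` pops `scale` out of `attention_kwargs`, applies it to every PEFT layer via `scale_lora_layers`, and undoes it with `unscale_lora_layers` after the pass. A sketch, assuming the Lumina2 pipeline forwards `attention_kwargs` to the transformer as other diffusers pipelines with this pattern do:

    # Halve the LoRA contribution for this call only.
    image = pipe(
        prompt="a watercolor fox",
        attention_kwargs={"scale": 0.5},
    ).images[0]

    # Or bake it in at a fixed strength and undo it later:
    pipe.fuse_lora(lora_scale=0.7)
    pipe.unfuse_lora()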
