Skip to content
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
e50a108
Add AuraFlowLoraLoaderMixin
Warlord-K Jul 30, 2024
658d058
Add comments, remove qkv fusion
Warlord-K Jul 30, 2024
4208d09
Add Tests
Warlord-K Jul 30, 2024
98b19f6
Add AuraFlowLoraLoaderMixin to documentation
Warlord-K Jul 30, 2024
71f8bac
Add Suggested changes
Warlord-K Aug 11, 2024
0eee03e
Change attention_kwargs->joint_attention_kwargs
Warlord-K Aug 12, 2024
4e4f780
Rebasing derp.
hameerabbasi Dec 13, 2024
c07d1f5
fix
hlky Dec 13, 2024
1b7f99f
fix
hlky Dec 13, 2024
875a3e0
Quality fixes.
hameerabbasi Dec 13, 2024
a242d7a
make style
hlky Dec 13, 2024
a73df6b
`make fix-copies`
hameerabbasi Dec 13, 2024
894eac0
`ruff check --fix`
hameerabbasi Dec 13, 2024
2b36416
Attept 1 to fix tests.
hameerabbasi Dec 15, 2024
6b762b8
Attept 2 to fix tests.
hameerabbasi Dec 15, 2024
bc2a466
Attept 3 to fix tests.
hameerabbasi Dec 15, 2024
1c79095
Address review comments.
hameerabbasi Dec 19, 2024
9454e84
Rebasing derp.
hameerabbasi Dec 19, 2024
5700e52
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 3, 2025
6da81f8
Merge branch 'main' into auraflow-lora
sayakpaul Jan 6, 2025
28a4918
Get more tests passing by copying from Flux. Address review comments.
hameerabbasi Jan 7, 2025
d6028cd
`joint_attention_kwargs`->`attention_kwargs`
hameerabbasi Jan 7, 2025
6e899a3
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 7, 2025
2d02c2c
Add `lora_scale` property for te LoRAs.
hameerabbasi Jan 7, 2025
2b934b4
Make test better.
hameerabbasi Jan 7, 2025
532013f
Remove useless property.
hameerabbasi Jan 7, 2025
0ea9ecd
Merge branch 'main' into auraflow-lora
hlky Jan 8, 2025
e06d8eb
Skip TE-only tests for AuraFlow.
hameerabbasi Jan 8, 2025
2b35909
Support LoRA for non-CLIP TEs.
hameerabbasi Jan 10, 2025
1ec07a1
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Jan 10, 2025
077a452
Merge branch 'main' into auraflow-lora
hlky Jan 10, 2025
3095644
Merge branch 'main' into auraflow-lora
hameerabbasi Jan 13, 2025
df28362
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Jan 19, 2025
7e63330
Restore LoRA tests.
hameerabbasi Jan 19, 2025
5620384
Undo adding LoRA support for non-CLIP TEs.
hameerabbasi Jan 19, 2025
cd691d3
Undo support for TE in AuraFlow LoRA.
hameerabbasi Jan 19, 2025
0fa5cd5
`make fix-copies`
hameerabbasi Jan 19, 2025
83e0825
Sync with upstream changes.
hameerabbasi Jan 19, 2025
12fbd11
Remove unneeded stuff.
hameerabbasi Jan 19, 2025
c602749
Merge branch 'main' into auraflow-lora
hameerabbasi Feb 26, 2025
cdd184d
Mirror `Lumina2`.
hameerabbasi Feb 26, 2025
ce1939b
Skip for MPS.
hameerabbasi Feb 26, 2025
3b9e655
Address review comments.
hameerabbasi Feb 26, 2025
c11b14d
Remove duplicated code.
hameerabbasi Feb 27, 2025
636f01c
Remove unnecessary code.
hameerabbasi Feb 27, 2025
75ba7da
Remove repeated docs.
hameerabbasi Mar 5, 2025
c2daa8a
Propagate attention.
hameerabbasi Mar 5, 2025
8aa2d69
Fix TE target modules.
hameerabbasi Mar 6, 2025
b19942f
MPS fix for LoRA tests.
hameerabbasi Mar 6, 2025
5091757
Unrelated TE LoRA tests fix.
hameerabbasi Mar 6, 2025
dee9074
Fix AuraFlow LoRA tests by applying to the right denoiser layers.
hameerabbasi Mar 26, 2025
6241109
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Mar 26, 2025
ed33194
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
65a3bf5
Apply style fixes
github-actions[bot] Apr 8, 2025
147a356
empty commit
sayakpaul Apr 8, 2025
0c91c1a
Fix the repo consistency issues.
hameerabbasi Apr 8, 2025
e97a83e
Remove unrelated changes.
hameerabbasi Apr 8, 2025
a5b78d1
Style.
hameerabbasi Apr 8, 2025
dbc8427
Fix `test_lora_fuse_nan`.
hameerabbasi Apr 8, 2025
22fc9d9
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
ea14465
fix quality issues.
sayakpaul Apr 8, 2025
a20d03d
`pytest.xfail` -> `ValueError`.
hameerabbasi Apr 8, 2025
fb5f5f7
Add back `skip_mps`.
hameerabbasi Apr 8, 2025
f88503b
Merge branch 'main' into auraflow-lora
sayakpaul Apr 8, 2025
12dc911
Apply style fixes
github-actions[bot] Apr 8, 2025
fd9ed52
Merge branch 'main' into auraflow-lora
sayakpaul Apr 9, 2025
e418c2f
Merge branch 'main' into auraflow-lora
sayakpaul Apr 10, 2025
5e537d1
Merge remote-tracking branch 'upstream/main' into auraflow-lora
hameerabbasi Apr 11, 2025
bc93160
`make fix-copies`
hameerabbasi Apr 11, 2025
2880ba4
Merge branch 'main' into auraflow-lora
sayakpaul Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
Expand Down Expand Up @@ -56,6 +57,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin

## LTXVideoLoraLoaderMixin

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
Expand Down Expand Up @@ -103,6 +104,7 @@ def text_encoder_attn_modules(text_encoder):
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
Expand Down
309 changes: 309 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,315 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t
super().unfuse_lora(components=components, **kwargs)


class AuraFlowLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.

<Tip warning={true}>

We support loading A1111 formatted LoRA checkpoints in a limited capacity.

This function is experimental and might change in the future.

</Tip>

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:

- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.

proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.

"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}

state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

return state_dict

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.

Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`AuraFlowTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.

Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}

if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")

if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

<Tip warning={true}>

This is an experimental API.

</Tip>

Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

Example:

```py
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

<Tip warning={true}>

This is an experimental API.

</Tip>

Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)


class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/models/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def text_encoder_attn_modules(text_encoder):
def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
Expand All @@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules


def text_encoder_mlp_modules(text_encoder):
def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
Expand Down
Loading
Loading