Commit 23ebbb4

leffff, asomoza, yiyixuxu, and cbensimon authored
Kandinsky 5 is finally in Diffusers! (huggingface#12478)
* add kandinsky5 transformer pipeline first version

---------

Co-authored-by: Álvaro Somoza <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Charles <[email protected]>
1 parent 1b456bd commit 23ebbb4
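For orientation, here is a minimal sketch of what the new pipeline enables. This example is not part of the commit: the checkpoint id is borrowed from the `fuse_lora` docstring added below, and the call arguments and `.frames` output follow the common Diffusers video-pipeline convention rather than a confirmed Kandinsky 5 signature.

```py
# Hypothetical usage sketch (not from this commit); argument names such as
# `num_frames` are assumptions modeled on other Diffusers video pipelines.
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

frames = pipe(prompt="A cat surfing a wave at sunset", num_frames=25).frames[0]
export_to_video(frames, "kandinsky5.mp4")
```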

File tree

13 files changed: +1957 −0 lines changed


docs/source/en/api/loaders/lora.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

 [[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin

+## KandinskyLoraLoaderMixin
+[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
+
 ## LoraBaseMixin

 [[autodoc]] loaders.lora_base.LoraBaseMixin
```

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -220,6 +220,7 @@
             "HunyuanVideoTransformer3DModel",
             "I2VGenXLUNet",
             "Kandinsky3UNet",
+            "Kandinsky5Transformer3DModel",
             "LatteTransformer3DModel",
             "LTXVideoTransformer3DModel",
             "Lumina2Transformer2DModel",
@@ -474,6 +475,7 @@
             "ImageTextPipelineOutput",
             "Kandinsky3Img2ImgPipeline",
             "Kandinsky3Pipeline",
+            "Kandinsky5T2VPipeline",
             "KandinskyCombinedPipeline",
             "KandinskyImg2ImgCombinedPipeline",
             "KandinskyImg2ImgPipeline",
@@ -912,6 +914,7 @@
         HunyuanVideoTransformer3DModel,
         I2VGenXLUNet,
         Kandinsky3UNet,
+        Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
         Lumina2Transformer2DModel,
@@ -1136,6 +1139,7 @@
         ImageTextPipelineOutput,
         Kandinsky3Img2ImgPipeline,
         Kandinsky3Pipeline,
+        Kandinsky5T2VPipeline,
         KandinskyCombinedPipeline,
         KandinskyImg2ImgCombinedPipeline,
         KandinskyImg2ImgPipeline,
```
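With these entries in place, both new classes are importable from the package root:

```py
from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
```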

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
         "SanaLoraLoaderMixin",
         "Lumina2LoraLoaderMixin",
         "WanLoraLoaderMixin",
+        "KandinskyLoraLoaderMixin",
         "HiDreamImageLoraLoaderMixin",
         "SkyReelsV2LoraLoaderMixin",
         "QwenImageLoraLoaderMixin",
@@ -115,6 +116,7 @@ def text_encoder_attn_modules(text_encoder):
         FluxLoraLoaderMixin,
         HiDreamImageLoraLoaderMixin,
         HunyuanVideoLoraLoaderMixin,
+        KandinskyLoraLoaderMixin,
         LoraLoaderMixin,
         LTXVideoLoraLoaderMixin,
         Lumina2LoraLoaderMixin,
```
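The mixin is likewise exported from `diffusers.loaders`. A minimal sketch of the pattern, assuming (as the rest of this commit suggests) that a pipeline gains the LoRA methods simply by inheriting the mixin; the class name here is illustrative:

```py
from diffusers import DiffusionPipeline
from diffusers.loaders import KandinskyLoraLoaderMixin


# Illustrative only: mixing KandinskyLoraLoaderMixin into a pipeline class is
# what exposes load_lora_weights, fuse_lora, unfuse_lora, and save_lora_weights.
class MyKandinsky5Pipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    pass
```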

src/diffusers/loaders/lora_pipeline.py

Lines changed: 285 additions & 0 deletions
```diff
@@ -3639,6 +3639,291 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         super().unfuse_lora(components=components, **kwargs)


+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`Kandinsky5Transformer3DModel`],
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+                    - A string, the *model id* of a pretrained model hosted on the Hub.
+                    - A path to a *directory* containing the model weights.
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
+            use_safetensors (`bool`, *optional*):
+                Whether to use safetensors for loading.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata.
+        """
+        # Load the main state dict first which has the LoRA layers
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            hotswap (`bool`, *optional*):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        # Load LoRA into transformer
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        Load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters.
+            transformer (`Kandinsky5Transformer3DModel`):
+                The transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata.
+        """
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer and text encoders.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process.
+            save_function (`Callable`):
+                The function to use to save the state dictionary.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer.
+        """
+        lora_layers = {}
+        lora_metadata = {}
+
+        if transformer_lora_layers:
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers`")
+
+        cls._save_lora_weights(
+            save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing.
+
+        Example:
+        ```py
+        from diffusers import Kandinsky5T2VPipeline
+
+        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+        pipeline.load_lora_weights("path/to/lora.safetensors")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of [`pipe.fuse_lora()`].
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
 class WanLoraLoaderMixin(LoraBaseMixin):
     r"""
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
```
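Taken together, the new mixin gives Kandinsky 5 pipelines the standard Diffusers LoRA workflow. A minimal sketch of that workflow, assuming `Kandinsky5T2VPipeline` inherits the mixin as this diff suggests, and using a placeholder LoRA path:

```py
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")

# Inspect a checkpoint without attaching it (placeholder path).
state_dict = pipe.lora_state_dict("path/to/lora.safetensors")
print(list(state_dict)[:5])

# Attach the adapter, fuse it into the base weights for inference, then undo.
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="style")
pipe.fuse_lora(lora_scale=0.7)
# ... run inference ...
pipe.unfuse_lora()
```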

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -91,6 +91,7 @@
     _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
+    _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -182,6 +183,7 @@
         HunyuanDiT2DModel,
         HunyuanVideoFramepackTransformer3DModel,
         HunyuanVideoTransformer3DModel,
+        Kandinsky5Transformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
         Lumina2Transformer2DModel,
```

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -27,6 +27,7 @@
     from .transformer_hidream_image import HiDreamImageTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
     from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
+    from .transformer_kandinsky import Kandinsky5Transformer3DModel
    from .transformer_ltx import LTXVideoTransformer3DModel
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
```
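Since `Kandinsky5Transformer3DModel` is a regular Diffusers model class, it should also load standalone via `from_pretrained`; the `subfolder` name below is an assumption based on how other Diffusers video checkpoints are laid out:

```py
from diffusers import Kandinsky5Transformer3DModel

# Assumed repo layout: transformer weights in a "transformer" subfolder.
transformer = Kandinsky5Transformer3DModel.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V", subfolder="transformer"
)
```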
