Skip to content

Commit 3ee899f

Browse files
authored
[LoRA] Support Wan (huggingface#10943)
* update * refactor image-to-video pipeline * update * fix copied from * use FP32LayerNorm
1 parent dcd77ce commit 3ee899f

File tree

9 files changed

+584
-85
lines changed

9 files changed

+584
-85
lines changed

src/diffusers/loaders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def text_encoder_attn_modules(text_encoder):
7474
"HunyuanVideoLoraLoaderMixin",
7575
"SanaLoraLoaderMixin",
7676
"Lumina2LoraLoaderMixin",
77+
"WanLoraLoaderMixin",
7778
]
7879
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
7980
_import_structure["ip_adapter"] = [
@@ -112,6 +113,7 @@ def text_encoder_attn_modules(text_encoder):
112113
SD3LoraLoaderMixin,
113114
StableDiffusionLoraLoaderMixin,
114115
StableDiffusionXLLoraLoaderMixin,
116+
WanLoraLoaderMixin,
115117
)
116118
from .single_file import FromSingleFileMixin
117119
from .textual_inversion import TextualInversionLoaderMixin

src/diffusers/loaders/lora_pipeline.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4115,6 +4115,311 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
41154115
super().unfuse_lora(components=components)
41164116

41174117

4118+
class WanLoraLoaderMixin(LoraBaseMixin):
4119+
r"""
4120+
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
4121+
"""
4122+
4123+
_lora_loadable_modules = ["transformer"]
4124+
transformer_name = TRANSFORMER_NAME
4125+
4126+
@classmethod
4127+
@validate_hf_hub_args
4128+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
4129+
def lora_state_dict(
4130+
cls,
4131+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
4132+
**kwargs,
4133+
):
4134+
r"""
4135+
Return state dict for lora weights and the network alphas.
4136+
4137+
<Tip warning={true}>
4138+
4139+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
4140+
4141+
This function is experimental and might change in the future.
4142+
4143+
</Tip>
4144+
4145+
Parameters:
4146+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4147+
Can be either:
4148+
4149+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
4150+
the Hub.
4151+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
4152+
with [`ModelMixin.save_pretrained`].
4153+
- A [torch state
4154+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
4155+
4156+
cache_dir (`Union[str, os.PathLike]`, *optional*):
4157+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
4158+
is not used.
4159+
force_download (`bool`, *optional*, defaults to `False`):
4160+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
4161+
cached versions if they exist.
4162+
4163+
proxies (`Dict[str, str]`, *optional*):
4164+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
4165+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
4166+
local_files_only (`bool`, *optional*, defaults to `False`):
4167+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
4168+
won't be downloaded from the Hub.
4169+
token (`str` or *bool*, *optional*):
4170+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
4171+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
4172+
revision (`str`, *optional*, defaults to `"main"`):
4173+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
4174+
allowed by Git.
4175+
subfolder (`str`, *optional*, defaults to `""`):
4176+
The subfolder location of a model file within a larger model repository on the Hub or locally.
4177+
4178+
"""
4179+
# Load the main state dict first which has the LoRA layers for either of
4180+
# transformer and text encoder or both.
4181+
cache_dir = kwargs.pop("cache_dir", None)
4182+
force_download = kwargs.pop("force_download", False)
4183+
proxies = kwargs.pop("proxies", None)
4184+
local_files_only = kwargs.pop("local_files_only", None)
4185+
token = kwargs.pop("token", None)
4186+
revision = kwargs.pop("revision", None)
4187+
subfolder = kwargs.pop("subfolder", None)
4188+
weight_name = kwargs.pop("weight_name", None)
4189+
use_safetensors = kwargs.pop("use_safetensors", None)
4190+
4191+
allow_pickle = False
4192+
if use_safetensors is None:
4193+
use_safetensors = True
4194+
allow_pickle = True
4195+
4196+
user_agent = {
4197+
"file_type": "attn_procs_weights",
4198+
"framework": "pytorch",
4199+
}
4200+
4201+
state_dict = _fetch_state_dict(
4202+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
4203+
weight_name=weight_name,
4204+
use_safetensors=use_safetensors,
4205+
local_files_only=local_files_only,
4206+
cache_dir=cache_dir,
4207+
force_download=force_download,
4208+
proxies=proxies,
4209+
token=token,
4210+
revision=revision,
4211+
subfolder=subfolder,
4212+
user_agent=user_agent,
4213+
allow_pickle=allow_pickle,
4214+
)
4215+
4216+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
4217+
if is_dora_scale_present:
4218+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
4219+
logger.warning(warn_msg)
4220+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
4221+
4222+
return state_dict
4223+
4224+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
4225+
def load_lora_weights(
4226+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
4227+
):
4228+
"""
4229+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
4230+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
4231+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
4232+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
4233+
dict is loaded into `self.transformer`.
4234+
4235+
Parameters:
4236+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
4237+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4238+
adapter_name (`str`, *optional*):
4239+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4240+
`default_{i}` where i is the total number of adapters being loaded.
4241+
low_cpu_mem_usage (`bool`, *optional*):
4242+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4243+
weights.
4244+
kwargs (`dict`, *optional*):
4245+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
4246+
"""
4247+
if not USE_PEFT_BACKEND:
4248+
raise ValueError("PEFT backend is required for this method.")
4249+
4250+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
4251+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4252+
raise ValueError(
4253+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4254+
)
4255+
4256+
# if a dict is passed, copy it instead of modifying it inplace
4257+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
4258+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
4259+
4260+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4261+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4262+
4263+
is_correct_format = all("lora" in key for key in state_dict.keys())
4264+
if not is_correct_format:
4265+
raise ValueError("Invalid LoRA checkpoint.")
4266+
4267+
self.load_lora_into_transformer(
4268+
state_dict,
4269+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
4270+
adapter_name=adapter_name,
4271+
_pipeline=self,
4272+
low_cpu_mem_usage=low_cpu_mem_usage,
4273+
)
4274+
4275+
@classmethod
4276+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
4277+
def load_lora_into_transformer(
4278+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
4279+
):
4280+
"""
4281+
This will load the LoRA layers specified in `state_dict` into `transformer`.
4282+
4283+
Parameters:
4284+
state_dict (`dict`):
4285+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
4286+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
4287+
encoder lora layers.
4288+
transformer (`WanTransformer3DModel`):
4289+
The Transformer model to load the LoRA layers into.
4290+
adapter_name (`str`, *optional*):
4291+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
4292+
`default_{i}` where i is the total number of adapters being loaded.
4293+
low_cpu_mem_usage (`bool`, *optional*):
4294+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
4295+
weights.
4296+
"""
4297+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
4298+
raise ValueError(
4299+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
4300+
)
4301+
4302+
# Load the layers corresponding to transformer.
4303+
logger.info(f"Loading {cls.transformer_name}.")
4304+
transformer.load_lora_adapter(
4305+
state_dict,
4306+
network_alphas=None,
4307+
adapter_name=adapter_name,
4308+
_pipeline=_pipeline,
4309+
low_cpu_mem_usage=low_cpu_mem_usage,
4310+
)
4311+
4312+
@classmethod
4313+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
4314+
def save_lora_weights(
4315+
cls,
4316+
save_directory: Union[str, os.PathLike],
4317+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
4318+
is_main_process: bool = True,
4319+
weight_name: str = None,
4320+
save_function: Callable = None,
4321+
safe_serialization: bool = True,
4322+
):
4323+
r"""
4324+
Save the LoRA parameters corresponding to the UNet and text encoder.
4325+
4326+
Arguments:
4327+
save_directory (`str` or `os.PathLike`):
4328+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
4329+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
4330+
State dict of the LoRA layers corresponding to the `transformer`.
4331+
is_main_process (`bool`, *optional*, defaults to `True`):
4332+
Whether the process calling this is the main process or not. Useful during distributed training and you
4333+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
4334+
process to avoid race conditions.
4335+
save_function (`Callable`):
4336+
The function to use to save the state dictionary. Useful during distributed training when you need to
4337+
replace `torch.save` with another method. Can be configured with the environment variable
4338+
`DIFFUSERS_SAVE_MODE`.
4339+
safe_serialization (`bool`, *optional*, defaults to `True`):
4340+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
4341+
"""
4342+
state_dict = {}
4343+
4344+
if not transformer_lora_layers:
4345+
raise ValueError("You must pass `transformer_lora_layers`.")
4346+
4347+
if transformer_lora_layers:
4348+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
4349+
4350+
# Save the model
4351+
cls.write_lora_layers(
4352+
state_dict=state_dict,
4353+
save_directory=save_directory,
4354+
is_main_process=is_main_process,
4355+
weight_name=weight_name,
4356+
save_function=save_function,
4357+
safe_serialization=safe_serialization,
4358+
)
4359+
4360+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
4361+
def fuse_lora(
4362+
self,
4363+
components: List[str] = ["transformer"],
4364+
lora_scale: float = 1.0,
4365+
safe_fusing: bool = False,
4366+
adapter_names: Optional[List[str]] = None,
4367+
**kwargs,
4368+
):
4369+
r"""
4370+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
4371+
4372+
<Tip warning={true}>
4373+
4374+
This is an experimental API.
4375+
4376+
</Tip>
4377+
4378+
Args:
4379+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
4380+
lora_scale (`float`, defaults to 1.0):
4381+
Controls how much to influence the outputs with the LoRA parameters.
4382+
safe_fusing (`bool`, defaults to `False`):
4383+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
4384+
adapter_names (`List[str]`, *optional*):
4385+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
4386+
4387+
Example:
4388+
4389+
```py
4390+
from diffusers import DiffusionPipeline
4391+
import torch
4392+
4393+
pipeline = DiffusionPipeline.from_pretrained(
4394+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
4395+
).to("cuda")
4396+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
4397+
pipeline.fuse_lora(lora_scale=0.7)
4398+
```
4399+
"""
4400+
super().fuse_lora(
4401+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
4402+
)
4403+
4404+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
4405+
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
4406+
r"""
4407+
Reverses the effect of
4408+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
4409+
4410+
<Tip warning={true}>
4411+
4412+
This is an experimental API.
4413+
4414+
</Tip>
4415+
4416+
Args:
4417+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
4418+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
4419+
"""
4420+
super().unfuse_lora(components=components)
4421+
4422+
41184423
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
41194424
def __init__(self, *args, **kwargs):
41204425
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
5454
"SanaTransformer2DModel": lambda model_cls, weights: weights,
5555
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
56+
"WanTransformer3DModel": lambda model_cls, weights: weights,
5657
}
5758

5859

0 commit comments

Comments
 (0)