Skip to content

Commit 96f5d4e

Browse files
Fix LoRA loading bug for the Wan2.2 low-noise model (#188)
1 parent 88cd350 commit 96f5d4e

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

diffsynth_engine/pipelines/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import numpy as np
44
from einops import rearrange
5-
from typing import Dict, List, Tuple, Union
5+
from typing import Dict, List, Tuple, Union, Optional
66
from PIL import Image
77

88
from diffsynth_engine.configs import BaseConfig, BaseStateDicts, LoraConfig
@@ -70,7 +70,11 @@ def load_loras(
7070
lora_list: List[Tuple[str, Union[float, LoraConfig]]],
7171
fused: bool = True,
7272
save_original_weight: bool = False,
73+
lora_converter: Optional[LoRAStateDictConverter] = None,
7374
):
75+
if not lora_converter:
76+
lora_converter = self.lora_converter
77+
7478
for lora_path, lora_item in lora_list:
7579
if isinstance(lora_item, float):
7680
lora_scale = lora_item
@@ -86,7 +90,7 @@ def load_loras(
8690
self.apply_scheduler_config(scheduler_config)
8791
logger.info(f"Applied scheduler args from LoraConfig: {scheduler_config}")
8892

89-
lora_state_dict = self.lora_converter.convert(state_dict)
93+
lora_state_dict = lora_converter.convert(state_dict)
9094
for model_name, state_dict in lora_state_dict.items():
9195
model = getattr(self, model_name)
9296
lora_args = []

diffsynth_engine/pipelines/wan_video.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,14 @@ def convert(self, state_dict):
9595
return state_dict
9696

9797

98+
class WanLowNoiseLoRAConverter(WanLoRAConverter):
    """LoRA state-dict converter for the low-noise expert of Wan2.2.

    The base ``WanLoRAConverter`` emits converted weights under the "dit"
    key (the model targeted by the default ``load_loras`` path); this
    subclass re-routes them to "dit2", the attribute where the pipeline
    keeps the second (low-noise) DiT.
    """

    def convert(self, state_dict):
        """Convert *state_dict* with the base rules, then re-key to "dit2"."""
        converted = super().convert(state_dict)
        return {"dit2": converted["dit"]}
101+
102+
98103
class WanVideoPipeline(BasePipeline):
99104
lora_converter = WanLoRAConverter()
105+
low_noise_lora_converter = WanLowNoiseLoRAConverter()
100106

101107
def __init__(
102108
self,
@@ -133,7 +139,13 @@ def __init__(
133139
self.image_encoder = image_encoder
134140
self.model_names = ["text_encoder", "dit", "dit2", "vae", "image_encoder"]
135141

136-
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
142+
def load_loras(
143+
self,
144+
lora_list: List[Tuple[str, float]],
145+
fused: bool = True,
146+
save_original_weight: bool = False,
147+
lora_converter: Optional[WanLoRAConverter] = None
148+
):
137149
assert self.config.tp_degree is None or self.config.tp_degree == 1, (
138150
"load LoRA is not allowed when tensor parallel is enabled; "
139151
"set tp_degree=None or tp_degree=1 during pipeline initialization"
@@ -142,10 +154,20 @@ def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, sav
142154
"load fused LoRA is not allowed when fully sharded data parallel is enabled; "
143155
"either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
144156
)
145-
super().load_loras(lora_list, fused, save_original_weight)
157+
super().load_loras(lora_list, fused, save_original_weight, lora_converter)
158+
159+
def load_loras_low_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
    """Load LoRA weights into the low-noise expert (``dit2``) of a Wan2.2 pipeline.

    Delegates to ``load_loras`` with the low-noise converter so the
    converted weights land on ``dit2`` instead of ``dit``.
    """
    # dit2 only exists for two-expert (Wan2.2) checkpoints.
    assert self.dit2 is not None, "low noise LoRA can only be applied to Wan2.2"
    self.load_loras(
        lora_list,
        fused=fused,
        save_original_weight=save_original_weight,
        lora_converter=self.low_noise_lora_converter,
    )
162+
163+
def load_loras_high_noise(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
    """Load LoRA weights into the high-noise expert (``dit``) of a Wan2.2 pipeline.

    The default converter already targets ``dit``, so no custom
    converter is passed down.
    """
    # Guard: restrict to two-expert (Wan2.2) checkpoints, matching the low-noise path.
    assert self.dit2 is not None, "high noise LoRA can only be applied to Wan2.2"
    self.load_loras(lora_list, fused=fused, save_original_weight=save_original_weight)
146166

147167
def unload_loras(self):
    """Remove all loaded LoRA weights from the DiT model(s) and the text encoder."""
    self.dit.unload_loras()
    # The second (low-noise) expert is only present for Wan2.2 checkpoints.
    low_noise_dit = self.dit2
    if low_noise_dit is not None:
        low_noise_dit.unload_loras()
    self.text_encoder.unload_loras()
150172

151173
def get_default_fps(self) -> int:

examples/wan_lora_low_noise.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import argparse
2+
3+
from diffsynth_engine import WanPipelineConfig
4+
from diffsynth_engine.pipelines import WanVideoPipeline
5+
from diffsynth_engine.utils.download import fetch_model
6+
from diffsynth_engine.utils.video import save_video
7+
8+
9+
if __name__ == "__main__":
10+
parser = argparse.ArgumentParser(description="Select the wan speech-to-video pipeline example to run.")
11+
parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on.")
12+
parser.add_argument("--parallelism", type=int, default=1, help="Number of parallel devices to use.")
13+
parser.add_argument("--lora_dir", type=str, default="", help="Directory for LoRA weights.")
14+
args = parser.parse_args()
15+
config = WanPipelineConfig.basic_config(
16+
model_path=fetch_model(
17+
"Wan-AI/Wan2.2-T2V-A14B",
18+
revision="bf16",
19+
path=[
20+
"high_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
21+
"high_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
22+
"high_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
23+
"high_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
24+
"high_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
25+
"high_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors",
26+
"low_noise_model/diffusion_pytorch_model-00001-of-00006-bf16.safetensors",
27+
"low_noise_model/diffusion_pytorch_model-00002-of-00006-bf16.safetensors",
28+
"low_noise_model/diffusion_pytorch_model-00003-of-00006-bf16.safetensors",
29+
"low_noise_model/diffusion_pytorch_model-00004-of-00006-bf16.safetensors",
30+
"low_noise_model/diffusion_pytorch_model-00005-of-00006-bf16.safetensors",
31+
"low_noise_model/diffusion_pytorch_model-00006-of-00006-bf16.safetensors",
32+
],
33+
),
34+
parallelism=args.parallelism,
35+
device=args.device,
36+
)
37+
pipe = WanVideoPipeline.from_pretrained(config)
38+
pipe.load_loras_high_noise([(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-high-000100.safetensors", 1.0)], fused=False, save_original_weight=False)
39+
pipe.load_loras_low_noise([(f"{args.lora_dir}/wan22-style1-violetevergarden-16-sel-2-low-4-000060.safetensors", 1.0)], fused=False, save_original_weight=False)
40+
41+
video = pipe(
42+
prompt="白天,晴天光,侧光,硬光,暖色调,中近景,中心构图,一个银色短发少女戴着精致的皇冠,穿着华丽的长裙,站在阳光明媚的花园中。她面向镜头微笑,眼睛闪烁着光芒。阳光从侧面照来,照亮了她的银色短发和华丽的服饰,营造出一种温暖而高贵的氛围。微风轻拂,吹动着她裙摆上的蕾丝花边,增添了几分动感。背景是盛开的花朵和绿意盎然的植物,为画面增色不少。,anime style",
43+
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
44+
num_frames=81,
45+
width=480,
46+
height=832,
47+
seed=42,
48+
)
49+
save_video(video, "wan22_t2v_lora.mp4", fps=pipe.get_default_fps())
50+
51+
del pipe

0 commit comments

Comments
 (0)