Commit e0712a2

add WanDMDPipeline (#219)
* add wan dmd pipeline, update lora converter
* add example
1 parent 70e0115 commit e0712a2

File tree

8 files changed: +233 −11 lines


diffsynth_engine/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
     SDXLImagePipeline,
     FluxImagePipeline,
     WanVideoPipeline,
+    WanDMDPipeline,
     QwenImagePipeline,
     Hunyuan3DShapePipeline,
 )
@@ -81,6 +82,7 @@
     "FluxIPAdapter",
     "FluxRedux",
     "WanVideoPipeline",
+    "WanDMDPipeline",
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
     "FluxInpaintingTool",

diffsynth_engine/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 from .sd_image import SDImagePipeline
 from .wan_video import WanVideoPipeline
 from .wan_s2v import WanSpeech2VideoPipeline
+from .wan_dmd import WanDMDPipeline
 from .qwen_image import QwenImagePipeline
 from .hunyuan3d_shape import Hunyuan3DShapePipeline
 from .z_image import ZImagePipeline
@@ -16,6 +17,7 @@
     "SDImagePipeline",
     "WanVideoPipeline",
     "WanSpeech2VideoPipeline",
+    "WanDMDPipeline",
     "QwenImagePipeline",
     "Hunyuan3DShapePipeline",
     "ZImagePipeline",

diffsynth_engine/pipelines/base.py

Lines changed: 1 addition & 1 deletion
@@ -145,7 +145,7 @@ def load_lora(self, path: str, scale: float, fused: bool = True, save_original_w
         self.load_loras([(path, scale)], fused, save_original_weight)
 
     def apply_scheduler_config(self, scheduler_config: Dict):
-        pass
+        self.noise_scheduler.update_config(scheduler_config)
 
     def unload_loras(self):
         raise NotImplementedError()
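
With this change, every pipeline inherits a working scheduler override from the base class instead of a no-op. A minimal usage sketch; the config key shown is illustrative, not taken from this diff:

# Hypothetical override pushed through the shared base method; which keys
# noise_scheduler.update_config accepts depends on the concrete scheduler
# and is an assumption here.
pipe.apply_scheduler_config({"shift": 5.0})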

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 0 additions & 3 deletions
@@ -393,9 +393,6 @@ def unload_loras(self):
         self.dit.unload_loras()
         self.noise_scheduler.restore_config()
 
-    def apply_scheduler_config(self, scheduler_config: Dict):
-        self.noise_scheduler.update_config(scheduler_config)
-
     def prepare_latents(
         self,
         latents: torch.Tensor,

diffsynth_engine/pipelines/wan_dmd.py
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+import torch
+import torch.distributed as dist
+from typing import Callable, List, Optional
+from tqdm import tqdm
+from PIL import Image
+
+from diffsynth_engine.pipelines.wan_video import WanVideoPipeline
+
+
+class WanDMDPipeline(WanVideoPipeline):
+    def prepare_latents(
+        self,
+        latents,
+        denoising_step_list,
+    ):
+        height, width = latents.shape[-2:]
+        height, width = height * self.upsampling_factor, width * self.upsampling_factor
+        sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps=1000)
+        sigmas = sigmas[[1000 - t for t in denoising_step_list] + [-1]]
+        timesteps = timesteps[[1000 - t for t in denoising_step_list]]
+        init_latents = latents.clone()
+
+        return init_latents, latents, sigmas, timesteps
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        input_image: Image.Image | None = None,
+        seed=None,
+        height=480,
+        width=832,
+        num_frames=81,
+        denoising_step_list: List[int] = None,
+        progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
+    ):
+        denoising_step_list = [1000, 750, 500, 250] if denoising_step_list is None else denoising_step_list
+        divisor = 32 if self.vae.z_dim == 48 else 16  # 32 for wan2.2 vae, 16 for wan2.1 vae
+        assert height % divisor == 0 and width % divisor == 0, f"height and width must be divisible by {divisor}"
+        assert (num_frames - 1) % 4 == 0, "num_frames must be 4X+1"
+
+        # Initialize noise
+        if dist.is_initialized() and seed is None:
+            raise ValueError("must provide a seed when parallelism is enabled")
+        noise = self.generate_noise(
+            (
+                1,
+                self.vae.z_dim,
+                (num_frames - 1) // 4 + 1,
+                height // self.upsampling_factor,
+                width // self.upsampling_factor,
+            ),
+            seed=seed,
+            device="cpu",
+            dtype=torch.float32,
+        ).to(self.device)
+        init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, denoising_step_list)
+        mask = torch.ones((1, 1, *latents.shape[2:]), dtype=latents.dtype, device=latents.device)
+
+        # Encode prompts
+        self.load_models_to_device(["text_encoder"])
+        prompt_emb_posi = self.encode_prompt(prompt)
+        prompt_emb_nega = None
+
+        # Encode image
+        image_clip_feature = self.encode_clip_feature(input_image, height, width)
+        image_y = self.encode_vae_feature(input_image, num_frames, height, width)
+        image_latents = self.encode_image_latents(input_image, height, width)
+        if image_latents is not None:
+            latents[:, :, : image_latents.shape[2], :, :] = image_latents
+            init_latents = latents.clone()
+            mask[:, :, : image_latents.shape[2], :, :] = 0
+
+        # Initialize sampler
+        self.sampler.initialize(sigmas=sigmas)
+
+        # Denoise
+        hide_progress = dist.is_initialized() and dist.get_rank() != 0
+        for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
+            if timestep.item() / 1000 >= self.config.boundary:
+                self.load_models_to_device(["dit"])
+                model = self.dit
+            else:
+                self.load_models_to_device(["dit2"])
+                model = self.dit2
+
+            timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
+            timestep = timestep.to(dtype=self.dtype, device=self.device)
+            # Classifier-free guidance
+            noise_pred = self.predict_noise_with_cfg(
+                model=model,
+                latents=latents,
+                timestep=timestep,
+                positive_prompt_emb=prompt_emb_posi,
+                negative_prompt_emb=prompt_emb_nega,
+                image_clip_feature=image_clip_feature,
+                image_y=image_y,
+                cfg_scale=1.0,
+                batch_cfg=self.config.batch_cfg,
+            )
+            # Scheduler
+            latents = self.sampler.step(latents, noise_pred, i)
+            latents = latents * mask + init_latents * (1 - mask)
+            if progress_callback is not None:
+                progress_callback(i + 1, len(timesteps), "DENOISING")
+
+        # Decode
+        self.load_models_to_device(["vae"])
+        frames = self.decode_video(latents, progress_callback=progress_callback)
+        frames = self.vae_output_to_image(frames)
+        return frames
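
For intuition, a standalone replay of the index arithmetic in prepare_latents: denoising steps are given on a 1..1000 scale and mapped to positions in the full 1000-step schedule, and sigmas takes one extra trailing index so the sampler has a terminal value for the last step. This sketch assumes the scheduler returns entries ordered from noisiest (t=1000) to least noisy (t=1):

# Replay of the indexing above with the default step list.
denoising_step_list = [1000, 750, 500, 250]
indices = [1000 - t for t in denoising_step_list]
print(indices)         # [0, 250, 500, 750] -> timestep indices
print(indices + [-1])  # sigma indices; -1 selects the terminal sigma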

diffsynth_engine/pipelines/wan_video.py

Lines changed: 35 additions & 7 deletions
@@ -43,6 +43,24 @@ def _from_diffsynth(self, state_dict):
             dit_dict[key] = lora_args
         return {"dit": dit_dict}
 
+    def _from_diffusers(self, state_dict):
+        dit_dict = {}
+        for key, param in state_dict.items():
+            if ".lora_down.weight" not in key:
+                continue
+
+            lora_args = {}
+            lora_args["up"] = state_dict[key.replace(".lora_down.weight", ".lora_up.weight")]
+            lora_args["down"] = param
+            lora_args["rank"] = lora_args["up"].shape[1]
+            if key.replace(".lora_down.weight", ".alpha") in state_dict:
+                lora_args["alpha"] = state_dict[key.replace(".lora_down.weight", ".alpha")]
+            else:
+                lora_args["alpha"] = lora_args["rank"]
+            key = key.replace("diffusion_model.", "").replace(".lora_down.weight", "")
+            dit_dict[key] = lora_args
+        return {"dit": dit_dict}
+
     def _from_civitai(self, state_dict):
         dit_dict = {}
         for key, param in state_dict.items():
@@ -86,6 +104,9 @@ def convert(self, state_dict):
         if "lora_unet_blocks_0_cross_attn_k.lora_down.weight" in state_dict:
             state_dict = self._from_fun(state_dict)
             logger.info("use fun format state dict")
+        elif "diffusion_model.blocks.0.cross_attn.k.lora_down.weight" in state_dict:
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
         elif "diffusion_model.blocks.0.cross_attn.k.lora_A.weight" in state_dict:
             state_dict = self._from_civitai(state_dict)
             logger.info("use civitai format state dict")
@@ -480,8 +501,8 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
 
         dit_state_dict, dit2_state_dict = None, None
         if isinstance(config.model_path, list):
-            high_noise_model_ckpt = [path for path in config.model_path if "high_noise_model" in path]
-            low_noise_model_ckpt = [path for path in config.model_path if "low_noise_model" in path]
+            high_noise_model_ckpt = [path for path in config.model_path if "high_noise" in path]
+            low_noise_model_ckpt = [path for path in config.model_path if "low_noise" in path]
             if high_noise_model_ckpt and low_noise_model_ckpt:
                 logger.info(f"loading high noise model state dict from {high_noise_model_ckpt} ...")
                 dit_state_dict = cls.load_model_checkpoint(
@@ -681,8 +702,9 @@ def has_any_key(*xs):
             config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
 
     def update_weights(self, state_dicts: WanStateDicts) -> None:
-        is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
-                                    ("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
+        is_dual_model_state_dict = isinstance(state_dicts.model, dict) and (
+            "high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model
+        )
         is_dual_model_pipeline = self.dit2 is not None
 
         if is_dual_model_state_dict != is_dual_model_pipeline:
@@ -694,15 +716,21 @@ def update_weights(self, state_dicts: WanStateDicts) -> None:
 
         if is_dual_model_state_dict:
             if "high_noise_model" in state_dicts.model:
-                self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype
+                )
             if "low_noise_model" in state_dicts.model:
-                self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
+                self.update_component(
+                    self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype
+                )
         else:
             self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
 
         self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
         self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
-        self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
+        self.update_component(
+            self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype
+        )
 
     def compile(self):
         self.dit.compile_repeated_blocks()
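
For reference, a minimal sketch of the key layout the new _from_diffusers branch expects; the tensor shapes and rank are illustrative assumptions, not values from this diff:

import torch

# Illustrative diffusers/lightx2v-style LoRA entry (rank 64 assumed).
state_dict = {
    "diffusion_model.blocks.0.cross_attn.k.lora_down.weight": torch.zeros(64, 5120),
    "diffusion_model.blocks.0.cross_attn.k.lora_up.weight": torch.zeros(5120, 64),
    "diffusion_model.blocks.0.cross_attn.k.alpha": torch.tensor(64.0),
}
# The converter keys the result as "blocks.0.cross_attn.k", with
# rank = up.shape[1] = 64 and alpha falling back to the rank when no
# ".alpha" entry is present.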

examples/wan_dmd_image_to_video.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from PIL import Image
+
+from diffsynth_engine import WanPipelineConfig
+from diffsynth_engine.pipelines import WanDMDPipeline
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.video import save_video
+
+
+if __name__ == "__main__":
+    config = WanPipelineConfig.basic_config(
+        model_path=fetch_model(
+            "lightx2v/Wan2.2-Distill-Models",
+            path=[
+                "wan2.2_i2v_A14b_high_noise_lightx2v_4step_1030.safetensors",
+                "wan2.2_i2v_A14b_low_noise_lightx2v_4step.safetensors",
+            ],
+        ),
+        parallelism=1,
+    )
+    pipe = WanDMDPipeline.from_pretrained(config)
+
+    image = Image.open("input/wan_i2v_input.jpg").convert("RGB")
+    video = pipe(
+        prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
+        input_image=image,
+        num_frames=81,
+        width=480,
+        height=832,
+        seed=42,
+        denoising_step_list=[1000, 750, 500, 250],
+    )
+    save_video(video, "wan_dmd_i2v.mp4", fps=pipe.get_default_fps())
+
+    del pipe

examples/wan_dmd_text_to_video.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from diffsynth_engine import WanPipelineConfig
+from diffsynth_engine.pipelines import WanDMDPipeline
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.video import save_video
+
+
+if __name__ == "__main__":
+    config = WanPipelineConfig.basic_config(
+        model_path=fetch_model(
+            "Wan-AI/Wan2.2-T2V-A14B-BF16",
+            path=[
+                "high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
+                "high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
+                "low_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors",
+            ],
+        ),
+        parallelism=1,
+    )
+    pipe = WanDMDPipeline.from_pretrained(config)
+    pipe.load_loras_high_noise(
+        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/high_noise_model.safetensors"), 1.0)],
+        fused=False,
+    )
+    pipe.load_loras_low_noise(
+        [(fetch_model("lightx2v/Wan2.2-Lightning", path="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V2.0/low_noise_model.safetensors"), 1.0)],
+        fused=False,
+    )
+
+    video = pipe(
+        prompt="Documentary photography style: a lively puppy runs swiftly across a lush green lawn. The puppy has brownish-yellow fur, both ears perked up, and a focused yet joyful expression. Sunlight falls on its body, making the coat look especially soft and shiny. The background is an open meadow dotted with a few wildflowers, with a blue sky and a few white clouds faintly visible in the distance. Strong sense of perspective, capturing the dynamism of the running puppy and the vitality of the surrounding grass. Medium shot, moving side-on view.",
+        num_frames=81,
+        width=480,
+        height=832,
+        seed=42,
+        denoising_step_list=[1000, 750, 500, 250],
+    )
+    save_video(video, "wan_dmd_t2v.mp4", fps=pipe.get_default_fps())
+
+    del pipe
