Commit 60a9db7

support more wan models

1 parent a98700f

File tree

8 files changed: +307, -42 lines

diffsynth/configs/model_config.py (3 additions, 0 deletions)

@@ -59,6 +59,7 @@
 from ..models.wan_video_text_encoder import WanTextEncoder
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_motion_controller import WanMotionControllerModel


 model_loader_configs = [
@@ -122,11 +123,13 @@
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
     (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
     (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+    (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.

diffsynth/models/wan_video_dit.py (14 additions, 0 deletions)

@@ -521,6 +521,20 @@ def from_civitai(self, state_dict):
                 "num_layers": 40,
                 "eps": 1e-6
             }
+        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
         else:
             config = {}
         return state_dict, config
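
The branch above registers a second 1.3B-scale architecture (48 input channels, 30 layers) keyed on the hash of the checkpoint's state-dict keys. As a rough illustration of how such a config is consumed when loading a checkpoint from disk, here is a hedged sketch: the file path is a placeholder, and the `state_dict_converter()` call assumes `WanModel` follows the same converter convention as the motion controller module introduced below.

```python
# Hedged sketch: resolve a local Wan DiT checkpoint to its architecture config.
# The path is a placeholder; WanModel.state_dict_converter() is assumed to follow
# the same convention used by WanMotionControllerModel in this commit.
from safetensors.torch import load_file
from diffsynth.models.wan_video_dit import WanModel

raw_state_dict = load_file("models/placeholder_wan_dit.safetensors")  # placeholder path
state_dict, config = WanModel.state_dict_converter().from_civitai(raw_state_dict)
# If the keys hash to 349723183fc063b2bfc10bb2835cf677, `config` describes the
# new 48-channel-input, 30-layer variant added in this commit.
model = WanModel(**config)
model.load_state_dict(state_dict)
```
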
diffsynth/models/wan_video_motion_controller.py (new file, 44 additions, 0 deletions)

@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)
+
+    @staticmethod
+    def state_dict_converter():
+        return WanMotionControllerModelDictConverter()
+
+
+
+class WanMotionControllerModelDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
+
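
The new module turns a scalar motion bucket id into six additive modulation vectors that the pipeline later folds into the DiT's `t_mod` (see the `model_fn_wan_video` change below). A minimal standalone sketch of exercising it, assuming the package is installed as in this commit; the bucket value 100 is an arbitrary example:

```python
import torch
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel

# Defaults match the 1.3B DiT (dim=1536); init() zeroes the last layer, so the
# controller starts as a no-op until trained weights are loaded.
controller = WanMotionControllerModel(freq_dim=256, dim=1536)
controller.init()

# The pipeline wraps the bucket id in a 1-element tensor (prepare_motion_bucket_id).
motion_bucket_id = torch.tensor([100.0])
emb = controller(motion_bucket_id)

# Reshape into per-block modulation, mirroring model_fn_wan_video:
# t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
t_mod_delta = emb.unflatten(1, (6, 1536))
print(t_mod_delta.shape)  # expected: torch.Size([1, 6, 1536])
```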

diffsynth/pipelines/wan_video.py (77 additions, 7 deletions)

@@ -18,6 +18,7 @@
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel



@@ -31,7 +32,8 @@ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
         self.use_unified_sequence_parallel = False
@@ -122,6 +124,22 @@ def enable_vram_management(self, num_persistent_param_in_dit=None):
                     computation_device=self.device,
                 ),
             )
+        if self.motion_controller is not None:
+            dtype = next(iter(self.motion_controller.parameters())).dtype
+            enable_vram_management(
+                self.motion_controller,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()


@@ -134,6 +152,7 @@ def fetch_models(self, model_manager: ModelManager):
         self.dit = model_manager.fetch_model("wan_video_dit")
         self.vae = model_manager.fetch_model("wan_video_vae")
         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+        self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")


     @staticmethod
@@ -185,6 +204,25 @@ def encode_image(self, image, end_image, num_frames, height, width):
         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
         y = y.to(dtype=self.torch_dtype, device=self.device)
         return {"clip_feature": clip_context, "y": y}
+
+
+    def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        control_video = self.preprocess_images(control_video)
+        control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+        latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+        return latents
+
+
+    def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        if control_video is not None:
+            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            if clip_feature is None or y is None:
+                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+                y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+            else:
+                y = y[:, -16:]
+            y = torch.concat([control_latents, y], dim=1)
+        return {"clip_feature": clip_feature, "y": y}


     def tensor2video(self, frames):
@@ -210,6 +248,11 @@ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18,

     def prepare_unified_sequence_parallel(self):
         return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}


     @torch.no_grad()
@@ -220,6 +263,7 @@ def __call__(
         input_image=None,
         end_image=None,
         input_video=None,
+        control_video=None,
         denoising_strength=1.0,
         seed=None,
         rand_device="cpu",
@@ -229,6 +273,7 @@
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
@@ -274,6 +319,17 @@
         else:
             image_emb = {}

+        # ControlNet
+        if control_video is not None:
+            self.load_models_to_device(["image_encoder", "vae"])
+            image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)

@@ -285,14 +341,24 @@
         usp_kwargs = self.prepare_unified_sequence_parallel()

         # Denoise
-        self.load_models_to_device(["dit"])
+        self.load_models_to_device(["dit", "motion_controller"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

             # Inference
-            noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
+            noise_pred_posi = model_fn_wan_video(
+                self.dit, motion_controller=self.motion_controller,
+                x=latents, timestep=timestep,
+                **prompt_emb_posi, **image_emb, **extra_input,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs
+            )
             if cfg_scale != 1.0:
-                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
+                noise_pred_nega = model_fn_wan_video(
+                    self.dit, motion_controller=self.motion_controller,
+                    x=latents, timestep=timestep,
+                    **prompt_emb_nega, **image_emb, **extra_input,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs
+                )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
                 noise_pred = noise_pred_posi
@@ -365,13 +431,15 @@ def update(self, hidden_states):

 def model_fn_wan_video(
     dit: WanModel,
-    x: torch.Tensor,
-    timestep: torch.Tensor,
-    context: torch.Tensor,
+    motion_controller: WanMotionControllerModel = None,
+    x: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     if use_unified_sequence_parallel:
@@ -382,6 +450,8 @@ def model_fn_wan_video(

     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)

     if dit.has_image_input:
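
Taken together, the pipeline changes add two optional call-time inputs: `control_video` (a list of frames routed through `prepare_controlnet_kwargs`) and `motion_bucket_id` (routed through the new motion controller). A hedged end-to-end sketch of calling the updated pipeline; it assumes the usual ModelManager / `from_model_manager` loading pattern from the DiffSynth-Studio example scripts, and the checkpoint paths and prompt are placeholders.

```python
import torch
from diffsynth import ModelManager, WanVideoPipeline

# Placeholder paths: substitute the downloaded 1.3B checkpoints plus the
# speed-control (motion controller) weights registered in model_config.py.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
    "models/Wan2.1-1.3b-speedcontrol-v1/model.safetensors",
])
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# motion_bucket_id feeds WanMotionControllerModel; control_video (a list of
# PIL.Image frames) would additionally be encoded into the `y` conditioning.
video = pipe(
    prompt="a dog running on the beach",
    num_frames=81,
    motion_bucket_id=100,
)
```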

examples/wanvideo/README.md (46 additions, 35 deletions)

@@ -10,34 +10,30 @@ cd DiffSynth-Studio
 pip install -e .
 ```

-Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority.
+## Model Zoo

-* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
-* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
-* [Sage Attention](https://github.com/thu-ml/SageAttention)
-* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.)
+|Developer|Name|Link|Scripts|
+|-|-|-|-|
+|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
+|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
+|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
+|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
+|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
+|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
+|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
+|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
+|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
+|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|

-## Inference
+## VRAM Usage

-### Wan-Video-1.3B-T2V
+* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` setting to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).

-Wan-Video-1.3B-T2V supports text-to-video and video-to-video. See [`./wan_1.3b_text_to_video.py`](./wan_1.3b_text_to_video.py).
+* FP8 quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!).

-Required VRAM: 6G
-
-https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
-
-Put sunglasses on the dog.
-
-https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
-
-[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
-
-### Wan-Video-14B-T2V
-
-Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
-
-We present a detailed table here. The model is tested on a single A100.
+We present a detailed table here. The model (14B text-to-video) is tested on a single A100.

 |`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting|
 |-|-|-|-|-|
@@ -47,31 +43,46 @@ We present a detailed table here. The model is tested on a single A100.
 |torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes|
 |torch.float8_e4m3fn|0|24.0s/it|10G||

-https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
+**We found that the 14B image-to-video model is more sensitive to precision, so if the generated video shows issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**

-### Parallel Inference
+## Efficient Attention Implementation

-1. Unified Sequence Parallel (USP)
+DiffSynth-Studio supports multiple attention implementations. If more than one of the following is installed, the one with the highest priority is used. However, we recommend using the default torch SDPA.

-```bash
-pip install xfuser>=0.4.3
-```
+* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
+* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
+* [Sage Attention](https://github.com/thu-ml/SageAttention)
+* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default; `torch>=2.5.0` is recommended)
+
+## Acceleration
+
+We support multiple acceleration solutions:
+* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py).
+
+* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py).

 ```bash
+pip install "xfuser>=0.4.3"
 torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py
 ```

-2. Tensor Parallel
+* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py).
+
+## Gallery
+
+1.3B text-to-video.
+
+https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8

-Tensor parallel module of Wan-Video-14B-T2V is still under development. An example script is provided in [`./wan_14b_text_to_video_tensor_parallel.py`](./wan_14b_text_to_video_tensor_parallel.py).
+Put sunglasses on the dog.

-### Wan-Video-14B-I2V
+https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb

-Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py).
+14B text-to-video.

-**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.**
+https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f

-![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39)
+14B image-to-video.

 https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
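
The two VRAM knobs described in the rewritten README map to two call sites in user code: the dtype chosen when the weights are loaded and the `num_persistent_param_in_dit` argument of `enable_vram_management` (defined in the pipeline diff above). A hedged sketch, assuming the usual DiffSynth-Studio loading pattern from the example scripts; the paths are placeholders and the 7B figure mirrors the VRAM table.

```python
import torch
from diffsynth import ModelManager, WanVideoPipeline

# FP8 weight storage: choose the dtype where the checkpoints are loaded, not on the pipeline.
model_manager = ModelManager(torch_dtype=torch.float8_e4m3fn, device="cpu")
model_manager.load_models([
    "models/Wan2.1-T2V-14B/diffusion_pytorch_model.safetensors",  # placeholder path
    "models/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",      # placeholder path
    "models/Wan2.1-T2V-14B/Wan2.1_VAE.pth",                       # placeholder path
])
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Fine-grained offload: keep roughly 7B DiT parameters resident on the GPU,
# trading some speed for lower VRAM (see the table above).
pipe.enable_vram_management(num_persistent_param_in_dit=7 * 10**9)
```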
