
Commit a33c63a

Merge pull request #518 from modelscope/wan-fun
Wan fun
2 parents 71eee78 + 3cc9764 · commit a33c63a

File tree

8 files changed: +383 −44 lines

diffsynth/configs/model_config.py

Lines changed: 6 additions & 0 deletions
@@ -59,6 +59,7 @@
 from ..models.wan_video_text_encoder import WanTextEncoder
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vae import WanVideoVAE
+from ..models.wan_video_motion_controller import WanMotionControllerModel


 model_loader_configs = [
@@ -120,11 +121,16 @@
     (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
+    (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
     (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
     (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
     (None, "1378ea763357eea97acdef78e65d6d96", ["wan_video_vae"], [WanVideoVAE], "civitai"),
     (None, "ccc42284ea13e1ad04693284c7a09be6", ["wan_video_vae"], [WanVideoVAE], "civitai"),
+    (None, "dbd5ec76bbf977983f972c151d545389", ["wan_video_motion_controller"], [WanMotionControllerModel], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
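
For context on how these entries are used (my own sketch, not part of the commit): each tuple pairs a fingerprint of a checkpoint's state-dict keys with the model classes to instantiate, which is how the loader recognizes the new Wan-Fun DiT variants and the motion-controller weights without any explicit flag. Roughly, with a stand-in hash function:

```python
import hashlib

def hash_state_dict_keys_stub(state_dict):
    # Stand-in for the repo's hash_state_dict_keys: fingerprint the set of parameter names.
    joined = ",".join(sorted(state_dict.keys()))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

def pick_loader_config(state_dict, model_loader_configs):
    # Return the (model_names, model_classes, origin) entry whose hash matches this checkpoint.
    key_hash = hash_state_dict_keys_stub(state_dict)
    for path, known_hash, model_names, model_classes, origin in model_loader_configs:
        if known_hash == key_hash:
            return model_names, model_classes, origin
    return None
```
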

diffsynth/models/wan_video_dit.py

Lines changed: 56 additions & 0 deletions
@@ -493,6 +493,62 @@ def from_civitai(self, state_dict):
                 "num_layers": 40,
                 "eps": 1e-6
             }
+        elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 48,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
         else:
             config = {}
         return state_dict, config
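
A quick sanity check on the four new DiT configs (my own reading of this PR, not text from the commit): the in_dim=36 entries match the usual image-conditioned input of 16 noise-latent channels plus a 20-channel y (4 mask channels and 16 VAE latent channels), while the in_dim=48 entries line up with the control path added in wan_video.py below, where y becomes 16 control latents concatenated with 16 reference latents and the mask is dropped. The dim=1536/num_layers=30 and dim=5120/num_layers=40 pairs are the small and large DiT sizes already present elsewhere in this file.

```python
# Channel bookkeeping for the DiT input, inferred from the pipeline changes in this PR.
noise_latent_channels = 16

# InP-style variants: y = 4 mask channels + 16 VAE latent channels.
in_dim_inp = noise_latent_channels + (4 + 16)          # 36
# Control-style variants: y = 16 control latents + 16 reference latents (mask dropped via y[:, -16:]).
in_dim_control = noise_latent_channels + (16 + 16)     # 48

assert (in_dim_inp, in_dim_control) == (36, 48)
```
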
diffsynth/models/wan_video_motion_controller.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)
+
+    @staticmethod
+    def state_dict_converter():
+        return WanMotionControllerModelDictConverter()
+
+
+
+class WanMotionControllerModelDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
+
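
To see what the new module produces, here is a minimal smoke test of mine (not part of the commit; it assumes a repository checkout on PYTHONPATH and a plain float32 CPU run): the controller maps a scalar motion bucket id to a (1, 6 * dim) embedding, which model_fn_wan_video later unflattens and adds to the DiT's six modulation vectors.

```python
import torch
from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel

ctrl = WanMotionControllerModel(freq_dim=256, dim=1536)
ctrl.init()  # zeroes the last projection, so an untrained controller adds nothing to t_mod

motion_bucket_id = torch.tensor([100.0])
emb = ctrl(motion_bucket_id)               # (1, 9216) == (1, 6 * 1536)
t_mod_delta = emb.unflatten(1, (6, 1536))  # the term added to t_mod in model_fn_wan_video
print(emb.shape, t_mod_delta.shape, float(t_mod_delta.abs().max()))  # max is 0.0 after init()
```
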

diffsynth/pipelines/wan_video.py

Lines changed: 87 additions & 10 deletions
@@ -18,6 +18,7 @@
 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel


@@ -31,7 +32,8 @@ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
         self.use_unified_sequence_parallel = False
@@ -122,6 +124,22 @@ def enable_vram_management(self, num_persistent_param_in_dit=None):
                     computation_device=self.device,
                 ),
             )
+        if self.motion_controller is not None:
+            dtype = next(iter(self.motion_controller.parameters())).dtype
+            enable_vram_management(
+                self.motion_controller,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()
@@ -134,6 +152,7 @@ def fetch_models(self, model_manager: ModelManager):
         self.dit = model_manager.fetch_model("wan_video_dit")
         self.vae = model_manager.fetch_model("wan_video_vae")
         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+        self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")


     @staticmethod
@@ -163,22 +182,47 @@ def encode_prompt(self, prompt, positive=True):
         return {"context": prompt_emb}


-    def encode_image(self, image, num_frames, height, width):
+    def encode_image(self, image, end_image, num_frames, height, width):
         image = self.preprocess_image(image.resize((width, height))).to(self.device)
         clip_context = self.image_encoder.encode_image([image])
         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
         msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
         msk = msk.transpose(1, 2)[0]

-        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
         y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
         y = torch.concat([msk, y])
         y = y.unsqueeze(0)
         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
         y = y.to(dtype=self.torch_dtype, device=self.device)
         return {"clip_feature": clip_context, "y": y}
+
+
+    def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        control_video = self.preprocess_images(control_video)
+        control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+        latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+        return latents
+
+
+    def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        if control_video is not None:
+            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            if clip_feature is None or y is None:
+                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+                y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+            else:
+                y = y[:, -16:]
+            y = torch.concat([control_latents, y], dim=1)
+        return {"clip_feature": clip_feature, "y": y}


     def tensor2video(self, frames):
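
To make the conditioning layout concrete, a worked shape example of mine (assumed sizes, not part of the diff): for an 81-frame, 480x832 run, encode_image's mask contributes 4 channels and the VAE latents 16, i.e. the 20-channel y that an in_dim=36 DiT expects on top of its 16 noise channels; the end_image branch only changes which pixel frames get encoded and flips the last mask frame to 1.

```python
# Worked shape example for encode_image, using the 8x spatial / 4x temporal compression
# the code above relies on (height//8, width//8, (num_frames - 1) // 4 + 1).
num_frames, height, width = 81, 480, 832
lat_t = (num_frames - 1) // 4 + 1            # 21 latent frames
lat_h, lat_w = height // 8, width // 8       # 60 x 104 latent pixels

msk_shape = (4, lat_t, lat_h, lat_w)         # first frame repeated 4x, then folded into 4 channels
vae_shape = (16, lat_t, lat_h, lat_w)        # latents of the first/end-frame-padded pixel video
y_shape   = (1, 4 + 16, lat_t, lat_h, lat_w) # what encode_image returns as "y"
print(msk_shape, vae_shape, y_shape)
```
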
@@ -204,6 +248,11 @@ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18,

     def prepare_unified_sequence_parallel(self):
         return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}


     @torch.no_grad()
@@ -212,7 +261,9 @@ def __call__(
         prompt,
         negative_prompt="",
         input_image=None,
+        end_image=None,
         input_video=None,
+        control_video=None,
         denoising_strength=1.0,
         seed=None,
         rand_device="cpu",
@@ -222,6 +273,7 @@
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
@@ -263,10 +315,21 @@
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, num_frames, height, width)
+            image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
         else:
             image_emb = {}

+        # ControlNet
+        if control_video is not None:
+            self.load_models_to_device(["image_encoder", "vae"])
+            image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)
@@ -278,14 +341,24 @@
             usp_kwargs = self.prepare_unified_sequence_parallel()

         # Denoise
-        self.load_models_to_device(["dit"])
+        self.load_models_to_device(["dit", "motion_controller"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

             # Inference
-            noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
+            noise_pred_posi = model_fn_wan_video(
+                self.dit, motion_controller=self.motion_controller,
+                x=latents, timestep=timestep,
+                **prompt_emb_posi, **image_emb, **extra_input,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs
+            )
             if cfg_scale != 1.0:
-                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
+                noise_pred_nega = model_fn_wan_video(
+                    self.dit, motion_controller=self.motion_controller,
+                    x=latents, timestep=timestep,
+                    **prompt_emb_nega, **image_emb, **extra_input,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs
+                )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
                 noise_pred = noise_pred_posi
@@ -358,13 +431,15 @@ def update(self, hidden_states):

 def model_fn_wan_video(
     dit: WanModel,
-    x: torch.Tensor,
-    timestep: torch.Tensor,
-    context: torch.Tensor,
+    motion_controller: WanMotionControllerModel = None,
+    x: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     if use_unified_sequence_parallel:
@@ -375,6 +450,8 @@ def model_fn_wan_video(

     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)

     if dit.has_image_input:
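
Putting the whole PR together, a hedged end-to-end sketch (mine, not from the commit; the checkpoint paths and file names below are placeholders, and from_model_manager/save_video are assumed to follow DiffSynth-Studio's existing Wan examples). With a Fun-InP checkpoint you can pass end_image for first/last-frame interpolation; with a Fun-Control checkpoint you pass control_video; motion_bucket_id only has an effect if a motion-controller checkpoint was loaded, since model_fn_wan_video adds its embedding to t_mod only when both are present.

```python
import torch
from PIL import Image
from diffsynth import ModelManager, WanVideoPipeline, save_video

# Placeholder paths; point these at a downloaded Wan2.1-Fun checkpoint set.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/Wan2.1-Fun/diffusion_pytorch_model.safetensors",  # wan_video_dit (InP or Control variant)
    "models/Wan2.1-Fun/models_t5_umt5-xxl-enc-bf16.pth",      # wan_video_text_encoder
    "models/Wan2.1-Fun/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",  # wan_video_image_encoder
    "models/Wan2.1-Fun/Wan2.1_VAE.pth",                       # wan_video_vae
    # "models/Wan2.1-Fun/motion_controller.safetensors",      # optional wan_video_motion_controller
])
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")

first_frame = Image.open("first_frame.jpg")
last_frame = Image.open("last_frame.jpg")

video = pipe(
    prompt="a cat surfing a wave, cinematic lighting",
    input_image=first_frame,
    end_image=last_frame,            # new in this PR: optional last-frame conditioning
    # control_video=[...],           # new in this PR: list of PIL frames for the Control variant
    motion_bucket_id=100,            # new in this PR: ignored unless a motion controller is loaded
    num_frames=81, height=480, width=832,
    num_inference_steps=50, cfg_scale=5.0, seed=0,
)
save_video(video, "wan_fun_demo.mp4", fps=15, quality=5)
```
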
