2 changes: 2 additions & 0 deletions diffsynth/configs/model_config.py
@@ -64,6 +64,7 @@
from ..models.wan_video_vace import VaceWanModel
from ..models.wav2vec import WanS2VAudioEncoder
from ..models.wan_video_animate_adapter import WanAnimateAdapter
from ..models.wan_video_mot import MotWanModel

from ..models.step1x_connector import Qwen2Connector

@@ -157,6 +158,7 @@
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5ec04e02b42d2580483ad69f4e76346a", ["wan_video_dit"], [WanModel], "civitai"),
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
(None, "5f90e66a0672219f12d9a626c8c21f61", ["wan_video_dit", "wan_video_vap"], [WanModel,MotWanModel], "diffusers"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
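Each row in this table pairs a state-dict key hash with the weight names, model classes, and source key format used to load a checkpoint; the new row lets a combined WanModel plus MotWanModel ("wan_video_vap") checkpoint in diffusers format be recognized. A minimal sketch of how such a table could be matched against a loaded checkpoint, reusing the hash_state_dict_keys helper imported elsewhere in this PR; the resolve_models function and table argument are illustrative, not diffsynth's actual loader API:

    # Illustrative sketch only; diffsynth's real loader logic is more involved.
    def resolve_models(state_dict, table):
        key_hash = hash_state_dict_keys(state_dict)  # hash of the checkpoint's key set
        for _, known_hash, names, classes, fmt in table:
            if known_hash == key_hash:
                return names, classes, fmt
        return None
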
29 changes: 28 additions & 1 deletion diffsynth/models/wan_video_dit.py
@@ -437,6 +437,11 @@ def from_diffusers(self, state_dict):
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
@@ -454,6 +459,14 @@
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
@@ -470,7 +483,7 @@ def from_diffusers(self, state_dict):
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param
if hash_state_dict_keys(state_dict) == "cb104773c6c2cb6df4f9529ad5c60d0b":
if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
config = {
"model_type": "t2v",
"patch_size": (1, 2, 2),
@@ -488,6 +501,20 @@
"cross_attn_norm": True,
"eps": 1e-6,
}
elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
config = {
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"out_dim": 16,
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict_, config
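The hunk above does three things: it extends the diffusers-to-internal key map with the image-conditioned cross-attention projections (k_img, v_img, norm_k_img) and the image embedder, it fixes the hash check to run on the converted dictionary (state_dict_) rather than the raw input, and it adds a config branch for the image-input 14B variant. A simplified sketch of the rename-then-detect pattern, with a placeholder function name and an abridged config; block-index handling is omitted:

    # Simplified illustration of the pattern used above (placeholder names).
    def convert_and_detect(raw_state_dict, rename_dict):
        converted = {rename_dict[k]: v for k, v in raw_state_dict.items() if k in rename_dict}
        if hash_state_dict_keys(converted) == "6bfcfb3b342cb286ce886889d519a77e":
            config = {"has_image_input": True, "dim": 5120, "num_layers": 40}  # abridged
        else:
            config = {}
        return converted, config
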
307 changes: 307 additions & 0 deletions diffsynth/models/wan_video_mot.py
@@ -0,0 +1,307 @@
import torch
from .wan_video_dit import DiTBlock, SelfAttention, CrossAttention, rope_apply, flash_attention, modulate, MLP
from .utils import hash_state_dict_keys
import einops
import torch.nn as nn


class MotSelfAttention(SelfAttention):
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
super().__init__(dim, num_heads, eps)
def forward(self, x, freqs, is_before_attn=False):
if is_before_attn:
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
q = rope_apply(q, freqs, self.num_heads)
k = rope_apply(k, freqs, self.num_heads)
return q, k, v
else:
return self.o(x)


class MotWanAttentionBlock(DiTBlock):
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
self.block_id = block_id

self.self_attn = MotSelfAttention(dim, num_heads, eps)


def forward(self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot):

# 1. prepare scale parameter
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
wan_block.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)

scale_params_mot_ref = self.modulation + t_mod_mot.float()
scale_params_mot_ref = einops.rearrange(scale_params_mot_ref, '(b n) t c -> b n t c', n=1)
shift_msa_mot_ref, scale_msa_mot_ref, gate_msa_mot_ref, c_shift_msa_mot_ref, c_scale_msa_mot_ref, c_gate_msa_mot_ref = scale_params_mot_ref.chunk(6, dim=2)

# 2. Self-attention
input_x = modulate(wan_block.norm1(x), shift_msa, scale_msa)
# original block self-attn
attn1 = wan_block.self_attn
q = attn1.norm_q(attn1.q(input_x))
k = attn1.norm_k(attn1.k(input_x))
v = attn1.v(input_x)
q = rope_apply(q, freqs, attn1.num_heads)
k = rope_apply(k, freqs, attn1.num_heads)

# mot block self-attn
norm_x_mot = einops.rearrange(self.norm1(x_mot.float()), 'b (n t) c -> b n t c', n=1)
norm_x_mot = modulate(norm_x_mot, shift_msa_mot_ref, scale_msa_mot_ref).type_as(x_mot)
norm_x_mot = einops.rearrange(norm_x_mot, 'b n t c -> b (n t) c', n=1)
q_mot, k_mot, v_mot = self.self_attn(norm_x_mot, freqs_mot, is_before_attn=True)

tmp_hidden_states = flash_attention(
torch.cat([q, q_mot], dim=-2),
torch.cat([k, k_mot], dim=-2),
torch.cat([v, v_mot], dim=-2),
num_heads=attn1.num_heads)

attn_output, attn_output_mot = torch.split(tmp_hidden_states, [q.shape[-2], q_mot.shape[-2]], dim=-2)

attn_output = attn1.o(attn_output)
x = wan_block.gate(x, gate_msa, attn_output)

attn_output_mot = self.self_attn(x=attn_output_mot, freqs=freqs_mot, is_before_attn=False)
# gate
attn_output_mot = einops.rearrange(attn_output_mot, 'b (n t) c -> b n t c', n=1)
attn_output_mot = attn_output_mot * gate_msa_mot_ref
attn_output_mot = einops.rearrange(attn_output_mot, 'b n t c -> b (n t) c', n=1)
x_mot = (x_mot.float() + attn_output_mot).type_as(x_mot)

# 3. cross-attention and feed-forward
x = x + wan_block.cross_attn(wan_block.norm3(x), context)
input_x = modulate(wan_block.norm2(x), shift_mlp, scale_mlp)
x = wan_block.gate(x, gate_mlp, wan_block.ffn(input_x))

x_mot = x_mot + self.cross_attn(self.norm3(x_mot), context_mot)
# modulate
norm_x_mot_ref = einops.rearrange(self.norm2(x_mot.float()), 'b (n t) c -> b n t c', n=1)
norm_x_mot_ref = (norm_x_mot_ref * (1 + c_scale_msa_mot_ref) + c_shift_msa_mot_ref).type_as(x_mot)
norm_x_mot_ref = einops.rearrange(norm_x_mot_ref, 'b n t c -> b (n t) c', n=1)
input_x_mot = self.ffn(norm_x_mot_ref)
# gate
input_x_mot = einops.rearrange(input_x_mot, 'b (n t) c -> b n t c', n=1)
input_x_mot = input_x_mot.float() * c_gate_msa_mot_ref
input_x_mot = einops.rearrange(input_x_mot, 'b n t c -> b (n t) c', n=1)
x_mot = (x_mot.float() + input_x_mot).type_as(x_mot)

return x, x_mot

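MotWanAttentionBlock runs self-attention jointly over the main video tokens and the mot (reference) tokens: the q/k/v of both streams are concatenated along the sequence dimension, attended to once, and the result is split back so each stream keeps its own output projection, gating, and modulation. A self-contained sketch of that idea, using PyTorch's scaled_dot_product_attention in place of the repository's flash_attention helper; names and shapes are illustrative:

    import torch
    import torch.nn.functional as F

    def joint_self_attention(q, k, v, q_mot, k_mot, v_mot, num_heads):
        # Concatenate both token streams along the sequence dimension.
        q_all = torch.cat([q, q_mot], dim=-2)
        k_all = torch.cat([k, k_mot], dim=-2)
        v_all = torch.cat([v, v_mot], dim=-2)
        b, s, d = q_all.shape
        split = lambda t: t.view(b, -1, num_heads, d // num_heads).transpose(1, 2)
        out = F.scaled_dot_product_attention(split(q_all), split(k_all), split(v_all))
        out = out.transpose(1, 2).reshape(b, s, d)
        # Give each stream back its own slice of the joint attention output.
        return torch.split(out, [q.shape[-2], q_mot.shape[-2]], dim=-2)
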

class MotWanModel(torch.nn.Module):
def __init__(
self,
mot_layers=(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
patch_size=(1, 2, 2),
has_image_input=True,
has_image_pos_emb=False,
dim=5120,
num_heads=40,
ffn_dim=13824,
freq_dim=256,
text_dim=4096,
in_dim=36,
eps=1e-6,
):
super().__init__()
self.mot_layers = mot_layers
self.freq_dim = freq_dim
self.dim = dim

self.mot_layers_mapping = {i: n for n, i in enumerate(self.mot_layers)}
self.head_dim = dim // num_heads

self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)

self.text_embedding = nn.Sequential(
nn.Linear(text_dim, dim),
nn.GELU(approximate='tanh'),
nn.Linear(dim, dim)
)
self.time_embedding = nn.Sequential(
nn.Linear(freq_dim, dim),
nn.SiLU(),
nn.Linear(dim, dim)
)
self.time_projection = nn.Sequential(
nn.SiLU(), nn.Linear(dim, dim * 6))
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb)

# mot blocks
self.blocks = torch.nn.ModuleList([
MotWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
for i in self.mot_layers
])


def patchify(self, x: torch.Tensor):
x = self.patch_embedding(x)
return x

def compute_freqs_mot(self, f, h, w, end: int = 1024, theta: float = 10000.0):
def precompute_freqs_cis(dim: int, start: int = 0, end: int = 1024, theta: float = 10000.0):
# 1d rope precompute
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
[: (dim // 2)].double() / dim))
freqs = torch.outer(torch.arange(start, end, device=freqs.device), freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

f_freqs_cis = precompute_freqs_cis(self.head_dim - 2 * (self.head_dim // 3), -f, end, theta)
h_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)
w_freqs_cis = precompute_freqs_cis(self.head_dim // 3, 0, end, theta)

freqs = torch.cat([
f_freqs_cis[:f].view(f, 1, 1, -1).expand(f, h, w, -1),
h_freqs_cis[:h].view(1, h, 1, -1).expand(f, h, w, -1),
w_freqs_cis[:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1)
return freqs

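compute_freqs_mot splits each head's channels across the temporal, height, and width axes, but starts the temporal positions at -f, so the reference tokens get negative frame indices (positions -f through -1), presumably placing them before the main clip's frames in RoPE space. For the default dim=5120 and num_heads=40 the split works out as follows (plain arithmetic, no new API):

    head_dim = 5120 // 40           # 128 channels per head
    h_dim = w_dim = head_dim // 3   # 42 channels each for height and width
    f_dim = head_dim - 2 * h_dim    # 44 channels for the temporal axis
    # precompute_freqs_cis returns dim // 2 complex values per axis:
    # 22 + 21 + 21 = 64 complex values = 128 real channels = head_dim
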
def forward(
self, wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot, block_id,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,

):

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward

block = self.blocks[self.mot_layers_mapping[block_id]]
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
x, x_mot = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot,
use_reentrant=False,
)
elif use_gradient_checkpointing:
x, x_mot = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot,
use_reentrant=False,
)
else:
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)

return x, x_mot

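MotWanModel.forward acts as a per-layer hook: for each DiT block index listed in mot_layers it fetches the matching MotWanAttentionBlock and runs the joint step, optionally under gradient checkpointing. A hypothetical sketch of how a pipeline might weave this into the base model's block loop; the loop, variable names, and the base block's call signature are assumptions, not the actual WanVideo pipeline code:

    # Hypothetical integration sketch; not the actual pipeline code.
    for block_id, wan_block in enumerate(wan_model.blocks):
        if block_id in mot_model.mot_layers_mapping:
            # Joint step: the mot stream attends together with the main stream.
            x, x_mot = mot_model(
                wan_block, x, context, t_mod, freqs,
                x_mot, context_mot, t_mod_mot, freqs_mot,
                block_id=block_id,
            )
        else:
            x = wan_block(x, context, t_mod, freqs)  # assumed DiTBlock signature
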
@staticmethod
def state_dict_converter():
return MotWanModelDictConverter()


class MotWanModelDictConverter:
def __init__(self):
pass

def from_diffusers(self, state_dict):

rename_dict = {
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
"blocks.0.scale_shift_table": "blocks.0.modulation",
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
"condition_embedder.time_proj.bias": "time_projection.1.bias",
"condition_embedder.time_proj.weight": "time_projection.1.weight",
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
"patch_embedding.bias": "patch_embedding.bias",
"patch_embedding.weight": "patch_embedding.weight",
"scale_shift_table": "head.modulation",
"proj_out.bias": "head.head.bias",
"proj_out.weight": "head.head.weight",
}
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
else:
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
Comment on lines +267 to +270 (Contributor, severity: medium):

This if/else block is redundant because both branches assign the exact same tuple to mot_layers. You can simplify this to a single assignment.

        mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)

mot_layers_mapping = {i: n for n, i in enumerate(mot_layers)}

state_dict_ = {}

for name, param in state_dict.items():
name = name.replace("_mot_ref", "")
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
if name.split(".")[1].isdigit():
block_id = int(name.split(".")[1])
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
if name_ in rename_dict:
name_ = rename_dict[name_]
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
state_dict_[name_] = param

if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
config = {
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
"has_image_input": True,
"patch_size": [1, 2, 2],
"in_dim": 36,
"dim": 5120,
"ffn_dim": 13824,
"freq_dim": 256,
"text_dim": 4096,
"num_heads": 40,
"eps": 1e-6
}
else:
config = {}
return state_dict_, config
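
MotWanModelDictConverter.from_diffusers keeps only the parameters whose names contain "_mot_ref", strips that suffix, remaps the original block indices onto the ten local MotWanAttentionBlocks, and finally detects the 14B VAP configuration by hashing the converted keys. A hedged usage sketch; the checkpoint filename is a placeholder, and strict=False is used because the converter only covers the mot-specific weights:

    # Illustrative usage; the checkpoint path is a placeholder.
    from safetensors.torch import load_file

    raw = load_file("wan_vap_14b.safetensors")   # diffusers-format weights
    state_dict, config = MotWanModel.state_dict_converter().from_diffusers(raw)
    model = MotWanModel(**config) if config else MotWanModel()
    missing, unexpected = model.load_state_dict(state_dict, strict=False)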


