Skip to content

Commit 1953e57

Browse files
author
Gleb Sterkin
committed
PR review pt.2
1 parent 89668b9 commit 1953e57

File tree

12 files changed

+132
-424
lines changed

12 files changed

+132
-424
lines changed

video/wan2.1/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Pass `--sampler euler` to use Euler sampling for step-distilled models:
9797
For text to video pipeline you can try [this 4 steps distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors)
9898

9999
```shell
100-
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
100+
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
101101
```
102102

103103
```shell
@@ -150,5 +150,5 @@ Recommended thresholds (1.3B):
150150
|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_005.gif)|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_01.gif)|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_025.gif)|
151151

152152
# References
153-
1. [Original WAN 2.1 implemetation](https://github.com/Wan-Video/Wan2.1)
153+
1. [Original WAN 2.1 implementation](https://github.com/Wan-Video/Wan2.1)
154154
2. [LightX2V](https://github.com/ModelTC/LightX2V)

video/wan2.1/img2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def quantization_predicate(name, m):
4949
)
5050
parser.add_argument(
5151
"--n-prompt",
52-
default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
52+
default="Text, watermarks, blurry image, JPEG artifacts",
5353
)
5454
parser.add_argument(
5555
"--teacache",

video/wan2.1/txt2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2025 Apple Inc.
1+
# Copyright © 2026 Apple Inc.
22

33
"""Generate videos from text using Wan2.1."""
44

@@ -48,7 +48,7 @@ def quantization_predicate(name, m):
4848
)
4949
parser.add_argument(
5050
"--n-prompt",
51-
default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
51+
default="Text, watermarks, blurry image, JPEG artifacts",
5252
)
5353
parser.add_argument(
5454
"--teacache",

video/wan2.1/wan/layers.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,6 @@ def _residual_gate(x, y, gate):
2222
return x + y * gate
2323

2424

25-
class WanRMSNorm(nn.Module):
26-
def __init__(self, dim: int, eps: float = 1e-5):
27-
super().__init__()
28-
self.eps = eps
29-
self.weight = mx.ones((dim,))
30-
31-
def __call__(self, x: mx.array) -> mx.array:
32-
return mx.fast.rms_norm(x, self.weight, self.eps)
33-
34-
3525
class WanSelfAttention(nn.Module):
3626
def __init__(
3727
self,
@@ -48,10 +38,10 @@ def __init__(
4838
self.qkv = nn.Linear(dim, dim * 3)
4939
self.o = nn.Linear(dim, dim)
5040

51-
self.norm_q = WanRMSNorm(dim, eps=eps)
52-
self.norm_k = WanRMSNorm(dim, eps=eps)
41+
self.norm_q = nn.RMSNorm(dim, eps=eps)
42+
self.norm_k = nn.RMSNorm(dim, eps=eps)
5343

54-
def _attend(self, x, grid_sizes, freqs):
44+
def _attend(self, x, grid_sizes):
5545
"""Compute self-attention. Returns attn output [B, n, L, d]."""
5646
B, L, _ = x.shape
5747
n, d = self.num_heads, self.head_dim
@@ -66,17 +56,17 @@ def _attend(self, x, grid_sizes, freqs):
6656
k = k.reshape(B, L, n, d)
6757
v = v.reshape(B, L, n, d)
6858

69-
q = rope_apply(q, grid_sizes, freqs)
70-
k = rope_apply(k, grid_sizes, freqs)
59+
q = rope_apply(q, grid_sizes, self.head_dim)
60+
k = rope_apply(k, grid_sizes, self.head_dim)
7161

7262
q = q.transpose(0, 2, 1, 3)
7363
k = k.transpose(0, 2, 1, 3)
7464
v = v.transpose(0, 2, 1, 3)
7565
return mx.fast.scaled_dot_product_attention(q, k, v, scale=self.head_dim**-0.5)
7666

77-
def __call__(self, x, grid_sizes, freqs):
67+
def __call__(self, x, grid_sizes):
7868
B, L, C = x.shape
79-
attn = self._attend(x, grid_sizes, freqs)
69+
attn = self._attend(x, grid_sizes)
8070
return self.o(attn.transpose(0, 2, 1, 3).reshape(B, L, C))
8171

8272

@@ -97,8 +87,8 @@ def __init__(
9787
self.kv = nn.Linear(dim, dim * 2)
9888
self.o = nn.Linear(dim, dim)
9989

100-
self.norm_q = WanRMSNorm(dim, eps=eps)
101-
self.norm_k = WanRMSNorm(dim, eps=eps)
90+
self.norm_q = nn.RMSNorm(dim, eps=eps)
91+
self.norm_k = nn.RMSNorm(dim, eps=eps)
10292

10393
def _attend(self, x, context, context_lens):
10494
"""Compute text cross-attention. Returns (q, attn_out) both [B, n, L, d]."""
@@ -147,7 +137,7 @@ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
147137
super().__init__(dim, num_heads, eps)
148138
self.k_img = nn.Linear(dim, dim)
149139
self.v_img = nn.Linear(dim, dim)
150-
self.norm_k_img = WanRMSNorm(dim, eps=eps)
140+
self.norm_k_img = nn.RMSNorm(dim, eps=eps)
151141

152142
def __call__(self, x, context, context_lens):
153143
img_ctx_len = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
@@ -220,7 +210,6 @@ def __call__(
220210
x: mx.array,
221211
e: mx.array,
222212
grid_sizes: list,
223-
freqs: dict,
224213
context: mx.array,
225214
context_lens: Optional[mx.array],
226215
) -> mx.array:
@@ -230,7 +219,6 @@ def __call__(
230219
y = self.self_attn(
231220
mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps),
232221
grid_sizes,
233-
freqs,
234222
)
235223
x = _residual_gate(x, y, e[:, 2])
236224

video/wan2.1/wan/model.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from einops import rearrange
1919

2020
from .layers import Head, WanAttentionBlock
21-
from .rope import precompute_rope_freqs
2221

2322

2423
@partial(mx.compile, shapeless=True)
@@ -101,19 +100,6 @@ def __init__(
101100
# Output head
102101
self.head = Head(dim, out_dim, patch_size, eps)
103102

104-
# Precompute RoPE frequencies (not saved in checkpoint)
105-
self._freqs = precompute_rope_freqs(
106-
max_frames=1024,
107-
max_height=1024,
108-
max_width=1024,
109-
head_dim=self.head_dim,
110-
theta=10000.0,
111-
)
112-
113-
@property
114-
def freqs(self):
115-
return self._freqs
116-
117103
def _embed_image(self, clip_fea: mx.array) -> mx.array:
118104
"""Project CLIP features through img_emb MLP."""
119105
x = self.img_emb_norm1(clip_fea)
@@ -205,7 +191,7 @@ def __call__(
205191
else:
206192
x_in = x
207193
for block in self.blocks:
208-
x = block(x, e, grid_sizes, self.freqs, context, context_lens)
194+
x = block(x, e, grid_sizes, context, context_lens)
209195
new_residual = x - x_in
210196

211197
# Output head
@@ -347,27 +333,3 @@ def _merge_qkv_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
347333
merged[key] = value
348334

349335
return merged
350-
351-
352-
def create_wan_model(model_size: str = "1.3B", **kwargs) -> WanModel:
353-
configs = {
354-
"1.3B": {
355-
"dim": 1536,
356-
"ffn_dim": 8960,
357-
"freq_dim": 256,
358-
"num_heads": 12,
359-
"num_layers": 30,
360-
},
361-
"14B": {
362-
"dim": 5120,
363-
"ffn_dim": 13824,
364-
"freq_dim": 256,
365-
"num_heads": 40,
366-
"num_layers": 40,
367-
},
368-
}
369-
if model_size not in configs:
370-
raise ValueError(f"Unknown model size: {model_size}")
371-
config = configs[model_size]
372-
config.update(kwargs)
373-
return WanModel(**config)

0 commit comments

Comments
 (0)