Skip to content

Commit e4cd847

Browse files
author
Gleb Sterkin
committed
PR review pt.2
1 parent 89668b9 commit e4cd847

File tree

12 files changed

+55
-368
lines changed

12 files changed

+55
-368
lines changed

video/wan2.1/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Pass `--sampler euler` to use Euler sampling for step-distilled models:
9797
For text to video pipeline you can try [this 4 steps distilled model](https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors)
9898

9999
```shell
100-
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/blob/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
100+
wget https://huggingface.co/lightx2v/Wan2.1-Distill-Models/resolve/main/wan2.1_t2v_14b_lightx2v_4step.safetensors
101101
```
102102

103103
```shell
@@ -150,5 +150,5 @@ Recommended thresholds (1.3B):
150150
|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_005.gif)|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_01.gif)|![WAN t2v 1.3B teacache=0.05](static/out_t2v_1_3b_teacache_025.gif)|
151151

152152
# References
153-
1. [Original WAN 2.1 implemetation](https://github.com/Wan-Video/Wan2.1)
153+
1. [Original WAN 2.1 implementation](https://github.com/Wan-Video/Wan2.1)
154154
2. [LightX2V](https://github.com/ModelTC/LightX2V)

video/wan2.1/img2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def quantization_predicate(name, m):
4949
)
5050
parser.add_argument(
5151
"--n-prompt",
52-
default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
52+
default="Text, watermarks, blurry image, JPEG artifacts",
5353
)
5454
parser.add_argument(
5555
"--teacache",

video/wan2.1/txt2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright © 2025 Apple Inc.
1+
# Copyright © 2026 Apple Inc.
22

33
"""Generate videos from text using Wan2.1."""
44

@@ -48,7 +48,7 @@ def quantization_predicate(name, m):
4848
)
4949
parser.add_argument(
5050
"--n-prompt",
51-
default="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
51+
default="Text, watermarks, blurry image, JPEG artifacts",
5252
)
5353
parser.add_argument(
5454
"--teacache",

video/wan2.1/wan/layers.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,6 @@ def _residual_gate(x, y, gate):
2222
return x + y * gate
2323

2424

25-
class WanRMSNorm(nn.Module):
26-
def __init__(self, dim: int, eps: float = 1e-5):
27-
super().__init__()
28-
self.eps = eps
29-
self.weight = mx.ones((dim,))
30-
31-
def __call__(self, x: mx.array) -> mx.array:
32-
return mx.fast.rms_norm(x, self.weight, self.eps)
33-
34-
3525
class WanSelfAttention(nn.Module):
3626
def __init__(
3727
self,
@@ -48,10 +38,10 @@ def __init__(
4838
self.qkv = nn.Linear(dim, dim * 3)
4939
self.o = nn.Linear(dim, dim)
5040

51-
self.norm_q = WanRMSNorm(dim, eps=eps)
52-
self.norm_k = WanRMSNorm(dim, eps=eps)
41+
self.norm_q = nn.RMSNorm(dim, eps=eps)
42+
self.norm_k = nn.RMSNorm(dim, eps=eps)
5343

54-
def _attend(self, x, grid_sizes, freqs):
44+
def _attend(self, x, grid_sizes):
5545
"""Compute self-attention. Returns attn output [B, n, L, d]."""
5646
B, L, _ = x.shape
5747
n, d = self.num_heads, self.head_dim
@@ -66,17 +56,17 @@ def _attend(self, x, grid_sizes, freqs):
6656
k = k.reshape(B, L, n, d)
6757
v = v.reshape(B, L, n, d)
6858

69-
q = rope_apply(q, grid_sizes, freqs)
70-
k = rope_apply(k, grid_sizes, freqs)
59+
q = rope_apply(q, grid_sizes, self.head_dim)
60+
k = rope_apply(k, grid_sizes, self.head_dim)
7161

7262
q = q.transpose(0, 2, 1, 3)
7363
k = k.transpose(0, 2, 1, 3)
7464
v = v.transpose(0, 2, 1, 3)
7565
return mx.fast.scaled_dot_product_attention(q, k, v, scale=self.head_dim**-0.5)
7666

77-
def __call__(self, x, grid_sizes, freqs):
67+
def __call__(self, x, grid_sizes):
7868
B, L, C = x.shape
79-
attn = self._attend(x, grid_sizes, freqs)
69+
attn = self._attend(x, grid_sizes)
8070
return self.o(attn.transpose(0, 2, 1, 3).reshape(B, L, C))
8171

8272

@@ -97,8 +87,8 @@ def __init__(
9787
self.kv = nn.Linear(dim, dim * 2)
9888
self.o = nn.Linear(dim, dim)
9989

100-
self.norm_q = WanRMSNorm(dim, eps=eps)
101-
self.norm_k = WanRMSNorm(dim, eps=eps)
90+
self.norm_q = nn.RMSNorm(dim, eps=eps)
91+
self.norm_k = nn.RMSNorm(dim, eps=eps)
10292

10393
def _attend(self, x, context, context_lens):
10494
"""Compute text cross-attention. Returns (q, attn_out) both [B, n, L, d]."""
@@ -147,7 +137,7 @@ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
147137
super().__init__(dim, num_heads, eps)
148138
self.k_img = nn.Linear(dim, dim)
149139
self.v_img = nn.Linear(dim, dim)
150-
self.norm_k_img = WanRMSNorm(dim, eps=eps)
140+
self.norm_k_img = nn.RMSNorm(dim, eps=eps)
151141

152142
def __call__(self, x, context, context_lens):
153143
img_ctx_len = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
@@ -220,7 +210,6 @@ def __call__(
220210
x: mx.array,
221211
e: mx.array,
222212
grid_sizes: list,
223-
freqs: dict,
224213
context: mx.array,
225214
context_lens: Optional[mx.array],
226215
) -> mx.array:
@@ -230,7 +219,6 @@ def __call__(
230219
y = self.self_attn(
231220
mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps),
232221
grid_sizes,
233-
freqs,
234222
)
235223
x = _residual_gate(x, y, e[:, 2])
236224

video/wan2.1/wan/model.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from einops import rearrange
1919

2020
from .layers import Head, WanAttentionBlock
21-
from .rope import precompute_rope_freqs
2221

2322

2423
@partial(mx.compile, shapeless=True)
@@ -101,19 +100,6 @@ def __init__(
101100
# Output head
102101
self.head = Head(dim, out_dim, patch_size, eps)
103102

104-
# Precompute RoPE frequencies (not saved in checkpoint)
105-
self._freqs = precompute_rope_freqs(
106-
max_frames=1024,
107-
max_height=1024,
108-
max_width=1024,
109-
head_dim=self.head_dim,
110-
theta=10000.0,
111-
)
112-
113-
@property
114-
def freqs(self):
115-
return self._freqs
116-
117103
def _embed_image(self, clip_fea: mx.array) -> mx.array:
118104
"""Project CLIP features through img_emb MLP."""
119105
x = self.img_emb_norm1(clip_fea)
@@ -205,7 +191,7 @@ def __call__(
205191
else:
206192
x_in = x
207193
for block in self.blocks:
208-
x = block(x, e, grid_sizes, self.freqs, context, context_lens)
194+
x = block(x, e, grid_sizes, context, context_lens)
209195
new_residual = x - x_in
210196

211197
# Output head
@@ -347,27 +333,3 @@ def _merge_qkv_weights(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
347333
merged[key] = value
348334

349335
return merged
350-
351-
352-
def create_wan_model(model_size: str = "1.3B", **kwargs) -> WanModel:
353-
configs = {
354-
"1.3B": {
355-
"dim": 1536,
356-
"ffn_dim": 8960,
357-
"freq_dim": 256,
358-
"num_heads": 12,
359-
"num_layers": 30,
360-
},
361-
"14B": {
362-
"dim": 5120,
363-
"ffn_dim": 13824,
364-
"freq_dim": 256,
365-
"num_heads": 40,
366-
"num_layers": 40,
367-
},
368-
}
369-
if model_size not in configs:
370-
raise ValueError(f"Unknown model size: {model_size}")
371-
config = configs[model_size]
372-
config.update(kwargs)
373-
return WanModel(**config)

video/wan2.1/wan/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def generate_latents(
365365
else:
366366
noise_pred = noise_cond
367367

368-
# Scheduler step
368+
# Scheduler step — async_eval starts GPU work before yielding
369+
# so the caller's mx.eval(x_t) blocks for less time.
369370
x_t = sampler.step(noise_pred, t, x_t)
370371
mx.async_eval(x_t)
371372
yield x_t

video/wan2.1/wan/rope.py

Lines changed: 5 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Uses mx.fast.rope for optimized Metal kernel.
88
"""
99

10-
from functools import partial
1110
from typing import Tuple
1211

1312
import mlx.core as mx
@@ -29,71 +28,6 @@ def get_rope_dimensions(head_dim: int) -> Tuple[int, int, int]:
2928
return frame_dim, height_dim, width_dim
3029

3130

32-
def precompute_rope_freqs(
33-
max_frames: int,
34-
max_height: int,
35-
max_width: int,
36-
head_dim: int,
37-
theta: float = 10000.0,
38-
) -> dict:
39-
"""
40-
Precompute RoPE frequencies for 3D positions.
41-
42-
Each axis gets its own frequency computation with its own dimension.
43-
"""
44-
frame_dim, height_dim, width_dim = get_rope_dimensions(head_dim)
45-
46-
dim_frame = frame_dim // 2
47-
dim_height = height_dim // 2
48-
dim_width = width_dim // 2
49-
50-
frame_inv_freq = 1.0 / (
51-
theta ** (mx.arange(0, frame_dim, 2, dtype=mx.float32) / frame_dim)
52-
)
53-
height_inv_freq = 1.0 / (
54-
theta ** (mx.arange(0, height_dim, 2, dtype=mx.float32) / height_dim)
55-
)
56-
width_inv_freq = 1.0 / (
57-
theta ** (mx.arange(0, width_dim, 2, dtype=mx.float32) / width_dim)
58-
)
59-
60-
frame_positions = mx.arange(max_frames, dtype=mx.float32)
61-
height_positions = mx.arange(max_height, dtype=mx.float32)
62-
width_positions = mx.arange(max_width, dtype=mx.float32)
63-
64-
frame_freqs = frame_positions[:, None] * frame_inv_freq[None, :]
65-
frame_cos, frame_sin = mx.cos(frame_freqs), mx.sin(frame_freqs)
66-
67-
height_freqs = height_positions[:, None] * height_inv_freq[None, :]
68-
height_cos, height_sin = mx.cos(height_freqs), mx.sin(height_freqs)
69-
70-
width_freqs = width_positions[:, None] * width_inv_freq[None, :]
71-
width_cos, width_sin = mx.cos(width_freqs), mx.sin(width_freqs)
72-
73-
return {
74-
"frame": {
75-
"cos": frame_cos,
76-
"sin": frame_sin,
77-
"dim": dim_frame,
78-
"full_dim": frame_dim,
79-
},
80-
"height": {
81-
"cos": height_cos,
82-
"sin": height_sin,
83-
"dim": dim_height,
84-
"full_dim": height_dim,
85-
},
86-
"width": {
87-
"cos": width_cos,
88-
"sin": width_sin,
89-
"dim": dim_width,
90-
"full_dim": width_dim,
91-
},
92-
"theta": theta,
93-
"head_dim": head_dim,
94-
}
95-
96-
9731
@mx.compile
9832
def _rope_3d(x, f, h, w, frame_dim, height_dim, width_dim, theta):
9933
B = x.shape[0]
@@ -129,24 +63,23 @@ def _rope_3d(x, f, h, w, frame_dim, height_dim, width_dim, theta):
12963
def rope_apply(
13064
x: mx.array,
13165
grid_sizes: list,
132-
freqs: dict,
66+
head_dim: int,
67+
theta: float = 10000.0,
13368
) -> mx.array:
13469
"""
13570
Apply 3D RoPE using mx.fast.rope with reshapes.
13671
13772
Args:
13873
x: Tensor of shape [B, L, H, D]
13974
grid_sizes: List of [frames, height, width] per batch element
140-
freqs: Precomputed frequencies from precompute_rope_freqs()
75+
head_dim: Dimension per attention head
76+
theta: RoPE base frequency
14177
14278
Returns:
14379
Rotated tensor with same shape as x
14480
"""
14581
f, h, w = grid_sizes[0]
14682

147-
theta = freqs["theta"]
148-
frame_dim = freqs["frame"]["full_dim"]
149-
height_dim = freqs["height"]["full_dim"]
150-
width_dim = freqs["width"]["full_dim"]
83+
frame_dim, height_dim, width_dim = get_rope_dimensions(head_dim)
15184

15285
return _rope_3d(x, f, h, w, frame_dim, height_dim, width_dim, theta)

video/wan2.1/wan/sampler.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ def __init__(
2929
solver_order: int = 2,
3030
prediction_type: str = "flow_prediction",
3131
shift: Optional[float] = 1.0,
32-
thresholding: bool = False,
33-
dynamic_thresholding_ratio: float = 0.995,
34-
sample_max_value: float = 1.0,
3532
predict_x0: bool = True,
3633
solver_type: str = "bh2",
3734
lower_order_final: bool = True,
@@ -48,9 +45,6 @@ def __init__(
4845
self.solver_order = solver_order
4946
self.prediction_type = prediction_type
5047
self.shift = shift
51-
self.thresholding = thresholding
52-
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
53-
self.sample_max_value = sample_max_value
5448
self.predict_x0 = predict_x0
5549
self.solver_type = solver_type
5650
self.lower_order_final = lower_order_final
@@ -111,35 +105,11 @@ def _sigma_to_alpha_sigma_t(self, sigma):
111105
return 1 - sigma, sigma
112106

113107
def convert_model_output(self, model_output, sample):
114-
sigma = self.sigmas[self.step_index]
115-
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
116-
108+
sigma_t = self.sigmas[self.step_index]
117109
if self.predict_x0:
118-
sigma_t = self.sigmas[self.step_index]
119-
x0_pred = sample - sigma_t * model_output
120-
if self.thresholding:
121-
x0_pred = self._threshold_sample(x0_pred)
122-
return x0_pred
110+
return sample - sigma_t * model_output
123111
else:
124-
sigma_t = self.sigmas[self.step_index]
125-
epsilon = sample - (1 - sigma_t) * model_output
126-
return epsilon
127-
128-
def _threshold_sample(self, sample):
129-
dtype = sample.dtype
130-
batch_size, channels, *remaining_dims = sample.shape
131-
num_elements = 1
132-
for d in remaining_dims:
133-
num_elements *= d
134-
sample = sample.reshape(batch_size, channels * num_elements)
135-
abs_sample = mx.abs(sample)
136-
sorted_abs = mx.sort(abs_sample, axis=1)
137-
quantile_idx = int(self.dynamic_thresholding_ratio * abs_sample.shape[1])
138-
s = sorted_abs[:, quantile_idx : quantile_idx + 1]
139-
s = mx.clip(s, 1.0, self.sample_max_value)
140-
sample = mx.clip(sample, -s, s) / s
141-
sample = sample.reshape(batch_size, channels, *remaining_dims)
142-
return sample.astype(dtype)
112+
return sample - (1 - sigma_t) * model_output
143113

144114
def multistep_uni_p_bh_update(self, model_output, sample, order):
145115
model_output_list = self.model_outputs
@@ -197,6 +167,8 @@ def multistep_uni_p_bh_update(self, model_output, sample, order):
197167
if order == 2:
198168
rhos_p = mx.array([0.5], dtype=x.dtype)
199169
else:
170+
# Run on CPU for numerical stability (float64 not supported on Metal GPU),
171+
# matching the reference implementation.
200172
with mx.stream(mx.cpu):
201173
rhos_p = mx.linalg.solve(R[:-1, :-1], b[:-1]).astype(x.dtype)
202174
else:
@@ -286,6 +258,8 @@ def multistep_uni_c_bh_update(
286258
if order == 1:
287259
rhos_c = mx.array([0.5], dtype=x.dtype)
288260
else:
261+
# Run on CPU for numerical stability (float64 not supported on Metal GPU),
262+
# matching the reference implementation.
289263
with mx.stream(mx.cpu):
290264
rhos_c = mx.linalg.solve(R, b).astype(x.dtype)
291265

0 commit comments

Comments (0)