Skip to content

Commit fa9e5eb

Browse files
author
Gleb Sterkin
committed
Reduce the number of reshapes, make everything channel-last, and change list[mx.array] to mx.array where possible
1 parent b303a7b commit fa9e5eb

File tree

3 files changed

+108
-149
lines changed

3 files changed

+108
-149
lines changed

video/wan2.1/wan/model.py

Lines changed: 50 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
import re
1313
from functools import partial
14-
from typing import Dict, List, Optional, Tuple
14+
from typing import Dict, Optional, Tuple
1515

1616
import mlx.core as mx
1717
import mlx.nn as nn
@@ -199,99 +199,75 @@ def _project_time(self, e: mx.array) -> mx.array:
199199

200200
def compute_time_embedding(self, t: mx.array):
201201
"""Compute time embeddings for TeaCache. Returns (t_emb, e0).
202-
t_emb: [B, dim] (pre-projection, used by head)
203-
e0: [B, 6*dim] (projected, used for block modulation)"""
202+
t_emb: [1, dim] (pre-projection, used by head)
203+
e0: [1, 6*dim] (projected, used for block modulation)"""
204204
t_emb = self._embed_time(t)
205205
e0 = self._project_time(t_emb)
206206
return t_emb, e0
207207

208208
def __call__(
209209
self,
210-
x: List[mx.array],
210+
x: mx.array,
211211
t: mx.array,
212-
context: List[mx.array],
213-
context_lens: Optional[List[int]] = None,
212+
context: mx.array,
213+
context_lens: Optional[int] = None,
214214
block_residual: Optional[mx.array] = None,
215215
precomputed_time: Optional[Tuple[mx.array, mx.array]] = None,
216216
clip_fea: Optional[mx.array] = None,
217-
first_frame: Optional[List[mx.array]] = None,
218-
) -> List[mx.array]:
217+
first_frame: Optional[mx.array] = None,
218+
) -> mx.array:
219219
"""
220220
Forward pass for t2v and i2v.
221221
222222
Args:
223-
x: List of input latents, each [C_in, F, H, W]
224-
t: Timesteps [B]
225-
context: List of text embeddings, each [L, C_text]
226-
context_lens: Actual context lengths (before padding)
223+
x: Input latent [F, H, W, C_in] (channels-last)
224+
t: Timestep [1]
225+
context: Text embedding [L, C_text]
226+
context_lens: Actual context length (before padding)
227227
block_residual: Precomputed block residual for TeaCache skip
228228
precomputed_time: (t_emb, e0) tuple for TeaCache
229-
clip_fea: CLIP image features [B, 257, 1280] (I2V only)
230-
first_frame: List of image conditioning [C_cond, F, H, W] (I2V only).
229+
clip_fea: CLIP image features [1, 257, 1280] (I2V only)
230+
first_frame: Image conditioning [F, H, W, C_cond] (I2V only).
231231
Concatenated channel-wise with x before patchify (in_dim=36).
232232
233233
Returns:
234-
List of output latents, each [C_out, F, H, W]
234+
Output latent [F, H, W, C_out] (channels-last)
235235
"""
236-
B = len(x)
237-
238236
# Channel-concat image conditioning before patchify (I2V)
239237
if first_frame is not None:
240-
x = [
241-
mx.concatenate([x_i, ff_i], axis=0) for x_i, ff_i in zip(x, first_frame)
242-
]
243-
244-
# Patchify and embed
245-
x_embedded = []
246-
grid_sizes = []
247-
seq_lens = []
248-
for x_i in x:
249-
x_i = x_i.transpose(1, 2, 3, 0)[None, :, :, :, :] # [1, F, H, W, C]
250-
x_i = mx.conv3d(
251-
x_i, self.patch_embedding_weight, stride=self.patch_size, padding=0
238+
x = mx.concatenate([x, first_frame], axis=-1)
239+
240+
# Patchify: [F, H, W, C] -> [1, F, H, W, C] -> conv3d -> [1, Fp, Hp, Wp, dim]
241+
x = x[None]
242+
x = mx.conv3d(x, self.patch_embedding_weight, stride=self.patch_size, padding=0)
243+
x = x + self.patch_embedding_bias[None, None, None, None, :]
244+
_, Fp, Hp, Wp, _ = x.shape
245+
grid_sizes = [[Fp, Hp, Wp]]
246+
x = x.reshape(1, Fp * Hp * Wp, self.dim)
247+
248+
# Embed context: [L, C_text] -> [1, text_len, dim]
249+
if context.shape[0] < self.text_len:
250+
pad_len = self.text_len - context.shape[0]
251+
context = mx.concatenate(
252+
[context, mx.zeros((pad_len, context.shape[1]))], axis=0
252253
)
253-
x_i = x_i + self.patch_embedding_bias[None, None, None, None, :]
254-
_, Fp, Hp, Wp, _ = x_i.shape
255-
x_i = x_i.reshape(Fp * Hp * Wp, self.dim)
256-
x_embedded.append(x_i)
257-
grid_sizes.append([Fp, Hp, Wp])
258-
seq_lens.append(Fp * Hp * Wp)
259-
260-
# Pad and stack into batch
261-
max_len = max(seq_lens)
262-
x_padded = []
263-
for x_i in x_embedded:
264-
if x_i.shape[0] < max_len:
265-
pad_len = max_len - x_i.shape[0]
266-
x_i = mx.concatenate([x_i, mx.zeros((pad_len, self.dim))], axis=0)
267-
x_padded.append(x_i)
268-
x = mx.stack(x_padded, axis=0)
269-
270-
# Pad and embed context
271-
context_padded = []
272-
for c_i in context:
273-
if c_i.shape[0] < self.text_len:
274-
pad_len = self.text_len - c_i.shape[0]
275-
c_i = mx.concatenate([c_i, mx.zeros((pad_len, c_i.shape[1]))], axis=0)
276-
context_padded.append(c_i)
277-
context_padded = mx.stack(context_padded, axis=0)
278-
context = self._embed_text(context_padded)
254+
context = self._embed_text(context[None])
279255

280256
# Prepend projected CLIP features to context (I2V)
281257
if clip_fea is not None:
282258
clip_proj = self._embed_image(clip_fea)
283259
context = mx.concatenate([clip_proj, context], axis=1)
284260

285261
if context_lens is not None:
286-
context_lens = mx.array(context_lens, dtype=mx.int32)
262+
context_lens = mx.array([context_lens], dtype=mx.int32)
287263

288-
# Time embedding (per-sample, not per-patch)
264+
# Time embedding
289265
if precomputed_time is not None:
290266
t_emb, e = precomputed_time[0], precomputed_time[1]
291267
else:
292268
t_emb = self._embed_time(t)
293269
e = self._project_time(t_emb)
294-
e = e.reshape(B, 6, self.dim) # [B, 6, dim]
270+
e = e.reshape(1, 6, self.dim)
295271

296272
# Transformer blocks
297273
if block_residual is not None:
@@ -301,31 +277,23 @@ def __call__(
301277
for i in range(self.num_layers):
302278
block = getattr(self, f"block_{i}")
303279
x = block(x, e, grid_sizes, self.freqs, context, context_lens)
304-
# Set by model.__call__ — read by pipeline for TeaCache caching
305280
self._last_block_residual = x - x_in
306281

307282
# Output head
308283
x = self.head(x, t_emb)
309284

310-
# Unpatchify
311-
outputs = []
312-
for i, (seq_len_i, grid_size) in enumerate(zip(seq_lens, grid_sizes)):
313-
x_i = x[i, :seq_len_i, :]
314-
Fp, Hp, Wp = grid_size
315-
pt, ph, pw = self.patch_size
316-
x_i = rearrange(
317-
x_i,
318-
"(Fp Hp Wp) (pt ph pw c) -> c (Fp pt) (Hp ph) (Wp pw)",
319-
Fp=Fp,
320-
Hp=Hp,
321-
Wp=Wp,
322-
pt=pt,
323-
ph=ph,
324-
pw=pw,
325-
)
326-
outputs.append(x_i)
327-
328-
return outputs
285+
# Unpatchify: [1, seq_len, patch_features] -> [F, H, W, C]
286+
pt, ph, pw = self.patch_size
287+
return rearrange(
288+
x[0],
289+
"(Fp Hp Wp) (pt ph pw c) -> (Fp pt) (Hp ph) (Wp pw) c",
290+
Fp=Fp,
291+
Hp=Hp,
292+
Wp=Wp,
293+
pt=pt,
294+
ph=ph,
295+
pw=pw,
296+
)
329297

330298
@staticmethod
331299
def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
@@ -334,6 +302,10 @@ def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
334302
for key, value in weights.items():
335303
new_key = key
336304

305+
# Skip fp8 scale metadata from LightX2V quantized checkpoints
306+
if "weight_scale" in new_key:
307+
continue
308+
337309
# Remove model. prefix
338310
if new_key.startswith("model."):
339311
new_key = new_key[6:]

0 commit comments

Comments
 (0)