1111import math
1212import re
1313from functools import partial
14- from typing import Dict , List , Optional , Tuple
14+ from typing import Dict , Optional , Tuple
1515
1616import mlx .core as mx
1717import mlx .nn as nn
@@ -199,99 +199,75 @@ def _project_time(self, e: mx.array) -> mx.array:
199199
def compute_time_embedding(self, t: mx.array) -> Tuple[mx.array, mx.array]:
    """Compute the time embeddings for TeaCache.

    Method of the model class (class header not visible in this view).
    The returned pair has the form ``__call__`` accepts through its
    ``precomputed_time`` argument, so the embedding work can be done once
    and reused.

    Args:
        t: Timestep array, forwarded unchanged to ``_embed_time``.

    Returns:
        ``(t_emb, e0)`` where
        t_emb: [1, dim]   (pre-projection, used by head)
        e0:    [1, 6*dim] (projected, used for block modulation)
    """
    t_emb = self._embed_time(t)
    # e0 is derived from t_emb, so the two must always be produced together.
    e0 = self._project_time(t_emb)
    return t_emb, e0
207207
208208 def __call__ (
209209 self ,
210- x : List [ mx .array ] ,
210+ x : mx .array ,
211211 t : mx .array ,
212- context : List [ mx .array ] ,
213- context_lens : Optional [List [ int ] ] = None ,
212+ context : mx .array ,
213+ context_lens : Optional [int ] = None ,
214214 block_residual : Optional [mx .array ] = None ,
215215 precomputed_time : Optional [Tuple [mx .array , mx .array ]] = None ,
216216 clip_fea : Optional [mx .array ] = None ,
217- first_frame : Optional [List [ mx .array ] ] = None ,
218- ) -> List [ mx .array ] :
217+ first_frame : Optional [mx .array ] = None ,
218+ ) -> mx .array :
219219 """
220220 Forward pass for t2v and i2v.
221221
222222 Args:
223- x: List of input latents, each [C_in, F, H, W]
224- t: Timesteps [B ]
225- context: List of text embeddings, each [L, C_text]
226- context_lens: Actual context lengths (before padding)
223+ x: Input latent [ F, H, W, C_in] (channels-last)
224+ t: Timestep [1 ]
225+ context: Text embedding [L, C_text]
226+ context_lens: Actual context length (before padding)
227227 block_residual: Precomputed block residual for TeaCache skip
228228 precomputed_time: (t_emb, e0) tuple for TeaCache
229- clip_fea: CLIP image features [B , 257, 1280] (I2V only)
230- first_frame: List of image conditioning [C_cond, F, H, W] (I2V only).
229+ clip_fea: CLIP image features [1 , 257, 1280] (I2V only)
230+ first_frame: Image conditioning [F, H, W, C_cond ] (I2V only).
231231 Concatenated channel-wise with x before patchify (in_dim=36).
232232
233233 Returns:
234- List of output latents, each [C_out, F, H, W]
234+ Output latent [ F, H, W, C_out] (channels-last)
235235 """
236- B = len (x )
237-
238236 # Channel-concat image conditioning before patchify (I2V)
239237 if first_frame is not None :
240- x = [
241- mx .concatenate ([x_i , ff_i ], axis = 0 ) for x_i , ff_i in zip (x , first_frame )
242- ]
243-
244- # Patchify and embed
245- x_embedded = []
246- grid_sizes = []
247- seq_lens = []
248- for x_i in x :
249- x_i = x_i .transpose (1 , 2 , 3 , 0 )[None , :, :, :, :] # [1, F, H, W, C]
250- x_i = mx .conv3d (
251- x_i , self .patch_embedding_weight , stride = self .patch_size , padding = 0
238+ x = mx .concatenate ([x , first_frame ], axis = - 1 )
239+
240+ # Patchify: [F, H, W, C] -> [1, F, H, W, C] -> conv3d -> [1, Fp, Hp, Wp, dim]
241+ x = x [None ]
242+ x = mx .conv3d (x , self .patch_embedding_weight , stride = self .patch_size , padding = 0 )
243+ x = x + self .patch_embedding_bias [None , None , None , None , :]
244+ _ , Fp , Hp , Wp , _ = x .shape
245+ grid_sizes = [[Fp , Hp , Wp ]]
246+ x = x .reshape (1 , Fp * Hp * Wp , self .dim )
247+
248+ # Embed context: [L, C_text] -> [1, text_len, dim]
249+ if context .shape [0 ] < self .text_len :
250+ pad_len = self .text_len - context .shape [0 ]
251+ context = mx .concatenate (
252+ [context , mx .zeros ((pad_len , context .shape [1 ]))], axis = 0
252253 )
253- x_i = x_i + self .patch_embedding_bias [None , None , None , None , :]
254- _ , Fp , Hp , Wp , _ = x_i .shape
255- x_i = x_i .reshape (Fp * Hp * Wp , self .dim )
256- x_embedded .append (x_i )
257- grid_sizes .append ([Fp , Hp , Wp ])
258- seq_lens .append (Fp * Hp * Wp )
259-
260- # Pad and stack into batch
261- max_len = max (seq_lens )
262- x_padded = []
263- for x_i in x_embedded :
264- if x_i .shape [0 ] < max_len :
265- pad_len = max_len - x_i .shape [0 ]
266- x_i = mx .concatenate ([x_i , mx .zeros ((pad_len , self .dim ))], axis = 0 )
267- x_padded .append (x_i )
268- x = mx .stack (x_padded , axis = 0 )
269-
270- # Pad and embed context
271- context_padded = []
272- for c_i in context :
273- if c_i .shape [0 ] < self .text_len :
274- pad_len = self .text_len - c_i .shape [0 ]
275- c_i = mx .concatenate ([c_i , mx .zeros ((pad_len , c_i .shape [1 ]))], axis = 0 )
276- context_padded .append (c_i )
277- context_padded = mx .stack (context_padded , axis = 0 )
278- context = self ._embed_text (context_padded )
254+ context = self ._embed_text (context [None ])
279255
280256 # Prepend projected CLIP features to context (I2V)
281257 if clip_fea is not None :
282258 clip_proj = self ._embed_image (clip_fea )
283259 context = mx .concatenate ([clip_proj , context ], axis = 1 )
284260
285261 if context_lens is not None :
286- context_lens = mx .array (context_lens , dtype = mx .int32 )
262+ context_lens = mx .array ([ context_lens ] , dtype = mx .int32 )
287263
288- # Time embedding (per-sample, not per-patch)
264+ # Time embedding
289265 if precomputed_time is not None :
290266 t_emb , e = precomputed_time [0 ], precomputed_time [1 ]
291267 else :
292268 t_emb = self ._embed_time (t )
293269 e = self ._project_time (t_emb )
294- e = e .reshape (B , 6 , self .dim ) # [B, 6, dim]
270+ e = e .reshape (1 , 6 , self .dim )
295271
296272 # Transformer blocks
297273 if block_residual is not None :
@@ -301,31 +277,23 @@ def __call__(
301277 for i in range (self .num_layers ):
302278 block = getattr (self , f"block_{ i } " )
303279 x = block (x , e , grid_sizes , self .freqs , context , context_lens )
304- # Set by model.__call__ — read by pipeline for TeaCache caching
305280 self ._last_block_residual = x - x_in
306281
307282 # Output head
308283 x = self .head (x , t_emb )
309284
310- # Unpatchify
311- outputs = []
312- for i , (seq_len_i , grid_size ) in enumerate (zip (seq_lens , grid_sizes )):
313- x_i = x [i , :seq_len_i , :]
314- Fp , Hp , Wp = grid_size
315- pt , ph , pw = self .patch_size
316- x_i = rearrange (
317- x_i ,
318- "(Fp Hp Wp) (pt ph pw c) -> c (Fp pt) (Hp ph) (Wp pw)" ,
319- Fp = Fp ,
320- Hp = Hp ,
321- Wp = Wp ,
322- pt = pt ,
323- ph = ph ,
324- pw = pw ,
325- )
326- outputs .append (x_i )
327-
328- return outputs
285+ # Unpatchify: [1, seq_len, patch_features] -> [F, H, W, C]
286+ pt , ph , pw = self .patch_size
287+ return rearrange (
288+ x [0 ],
289+ "(Fp Hp Wp) (pt ph pw c) -> (Fp pt) (Hp ph) (Wp pw) c" ,
290+ Fp = Fp ,
291+ Hp = Hp ,
292+ Wp = Wp ,
293+ pt = pt ,
294+ ph = ph ,
295+ pw = pw ,
296+ )
329297
330298 @staticmethod
331299 def sanitize (weights : Dict [str , mx .array ]) -> Dict [str , mx .array ]:
@@ -334,6 +302,10 @@ def sanitize(weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
334302 for key , value in weights .items ():
335303 new_key = key
336304
305+ # Skip fp8 scale metadata from LightX2V quantized checkpoints
306+ if "weight_scale" in new_key :
307+ continue
308+
337309 # Remove model. prefix
338310 if new_key .startswith ("model." ):
339311 new_key = new_key [6 :]
0 commit comments