@@ -236,15 +236,42 @@ def forward(
         post_patch_width = width // p
         original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
 
+        if indices_latents is None:
+            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+        hidden_states = self.x_embedder(hidden_states)
+        image_rotary_emb = self.rope(
+            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+        )
+
+        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+            latents_clean, latents_history_2x, latents_history_4x
+        )
+
+        if latents_clean is not None and indices_latents_clean is not None:
+            image_rotary_emb_clean = self.rope(
+                frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_2x is not None and indices_latents_history_2x is not None:
+            image_rotary_emb_history_2x = self.rope(
+                frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_4x is not None and indices_latents_history_4x is not None:
+            image_rotary_emb_history_4x = self.rope(
+                frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
+            )
+
         hidden_states, image_rotary_emb = self._pack_history_states(
             hidden_states,
-            indices_latents,
             latents_clean,
             latents_history_2x,
             latents_history_4x,
-            indices_latents_clean,
-            indices_latents_history_2x,
-            indices_latents_history_4x,
+            image_rotary_emb,
+            image_rotary_emb_clean,
+            image_rotary_emb_history_2x,
+            image_rotary_emb_history_4x,
+            post_patch_height,
+            post_patch_width,
         )
 
         temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
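The refactored `forward` now computes every rotary table once, up front, via `self.rope`, and hands the resulting `(freqs_cos, freqs_sin)` tuples to `_pack_history_states`, which is left with pure concatenation. Since history latents are prepended to `hidden_states` along the sequence dimension, their rotary rows must be prepended to the table in the same order. A minimal sketch of that invariant, using `rope_like` as a hypothetical stand-in that only mimics the one-row-per-token layout (it is not the model's actual rope module):

import torch

def rope_like(num_frames: int, height: int, width: int, head_dim: int = 8):
    # Hypothetical: one (cos, sin) row per token, flattened over (T, H, W).
    n_tokens = num_frames * height * width
    angles = torch.arange(n_tokens, dtype=torch.float32).unsqueeze(1)
    freqs = angles / (10.0 ** (torch.arange(head_dim) / head_dim))
    return freqs.cos(), freqs.sin()

# Current chunk: 2 latent frames on a 4x4 post-patch grid; clean history: 1 frame.
cos_cur, sin_cur = rope_like(num_frames=2, height=4, width=4)
cos_clean, sin_clean = rope_like(num_frames=1, height=4, width=4)

# Packing prepends the clean-history tokens, so their rotary rows are prepended too:
cos = torch.cat([cos_clean, cos_cur], dim=0)
print(cos.shape)  # torch.Size([48, 8]) == (1 + 2) * 4 * 4 tokens, one row each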
@@ -318,76 +345,48 @@ def forward(
     def _pack_history_states(
         self,
         hidden_states: torch.Tensor,
-        indices_latents: torch.Tensor,
         latents_clean: Optional[torch.Tensor] = None,
         latents_history_2x: Optional[torch.Tensor] = None,
         latents_history_4x: Optional[torch.Tensor] = None,
-        indices_latents_clean: Optional[torch.Tensor] = None,
-        indices_latents_history_2x: Optional[torch.Tensor] = None,
-        indices_latents_history_4x: Optional[torch.Tensor] = None,
+        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
+        image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        height: int = None,
+        width: int = None,
     ):
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        if indices_latents is None:
-            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
-
-        hidden_states = self.x_embedder(hidden_states)
-        image_rotary_emb = self.rope(
-            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
-        )
         image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
-        pph, ppw = height // self.config.patch_size, width // self.config.patch_size
 
-        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
-            latents_clean, latents_history_2x, latents_history_4x
-        )
-
-        if latents_clean is not None:
+        if latents_clean is not None and image_rotary_emb_clean is not None:
             hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
-
-            image_rotary_emb_clean = self.rope(
-                frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
-            )
             image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
             image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
 
-        if latents_history_2x is not None and indices_latents_history_2x is not None:
+        if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
             hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
-
-            image_rotary_emb_history_2x = self.rope(
-                frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
-            )
-            image_rotary_emb_history_2x = self._pad_rotary_emb(
-                image_rotary_emb_history_2x, indices_latents_history_2x.size(0), pph, ppw, (2, 2, 2)
-            )
+            image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
             image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
             image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
 
-        if latents_history_4x is not None and indices_latents_history_4x is not None:
+        if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
             hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
-
-            image_rotary_emb_history_4x = self.rope(
-                frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
-            )
-            image_rotary_emb_history_4x = self._pad_rotary_emb(
-                image_rotary_emb_history_4x, indices_latents_history_4x.size(0), pph, ppw, (4, 4, 4)
-            )
+            image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
             image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
             image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
 
-        return hidden_states, image_rotary_emb
+        return hidden_states, tuple(image_rotary_emb)
 
     def _pad_rotary_emb(
         self,
         image_rotary_emb: Tuple[torch.Tensor],
-        num_frames: int,
         height: int,
         width: int,
         kernel_size: Tuple[int, int, int],
     ):
         # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
         freqs_cos, freqs_sin = image_rotary_emb
-        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width))
-        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width))
+        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
         freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
         freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
         freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
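`_pad_rotary_emb` now infers the frame count from the tensor itself (`unflatten(2, (-1, height, width))`) instead of taking `num_frames`, which is why that parameter and the `indices_latents_history_2x.size(0)` plumbing could be dropped. The pooling is delegated to the module-level helpers `_pad_for_3d_conv` and `_center_down_sample_3d`, which these hunks do not show. A plausible sketch of their behavior, inferred from the call sites (treat the bodies as assumptions, not the file's actual code):

import torch
import torch.nn.functional as F

def _pad_for_3d_conv(x, kernel_size):
    # Replicate-pad the trailing (T, H, W) dims of a [B, C, T, H, W] tensor
    # up to the next multiple of the kernel size in each dimension.
    _, _, t, h, w = x.shape
    pt, ph, pw = kernel_size
    pad_t = (pt - t % pt) % pt
    pad_h = (ph - h % ph) % ph
    pad_w = (pw - w % pw) % pw
    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")

def _center_down_sample_3d(x, kernel_size):
    # Collapse each kernel-sized neighborhood to a single value, so one
    # rotary row remains per token of the 2x- or 4x-compressed history.
    return F.avg_pool3d(x, kernel_size, stride=kernel_size)

Under this reading, a [1, D/2, T, H, W] frequency table run through (2, 2, 2) padding and pooling shrinks by a factor of two along each of T, H, and W, matching the token count of `latents_history_2x`.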