@@ -513,10 +513,10 @@ def _prepare_non_first_frame_conditioning(
513513        frame_index : int ,
514514        strength : float ,
515515        num_prefix_latent_frames : int  =  2 ,
516-         prefix_latents_mode : str  =  "soft " ,
516+         prefix_latents_mode : str  =  "concat " ,
517517        prefix_soft_conditioning_strength : float  =  0.15 ,
518518    ) ->  Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
519-         num_latent_frames  =  latents .size (2 )
519+         num_latent_frames  =  condition_latents .size (2 )
520520
521521        if  num_latent_frames  <  num_prefix_latent_frames :
522522            raise  ValueError (
@@ -602,7 +602,7 @@ def prepare_latents(
602602        extra_conditioning_num_latents  =  (
603603            0   # Number of extra conditioning latents added (should be removed before decoding) 
604604        )
605-         condition_latent_frames_mask  =  torch .zeros ((batch_size , num_latent_frames ), device = device , dtype = dtype )
605+         condition_latent_frames_mask  =  torch .zeros ((batch_size , num_latent_frames ), device = device , dtype = torch . float32 )
606606
607607        for  condition  in  conditions :
608608            if  condition .image  is  not None :
@@ -632,7 +632,7 @@ def prepare_latents(
632632                    latents [:, :, :num_cond_frames ], condition_latents , condition .strength 
633633                )
634634                condition_latent_frames_mask [:, :num_cond_frames ] =  condition .strength 
635-              # YiYi TODO: code path not tested 
635+ 
636636            else :
637637                if  num_data_frames  >  1 :
638638                    (
@@ -651,47 +651,41 @@ def prepare_latents(
651651                    noise  =  randn_tensor (condition_latents .shape , generator = generator , device = device , dtype = dtype )
652652                    condition_latents  =  torch .lerp (noise , condition_latents , condition .strength )
653653                    c_nlf  =  condition_latents .shape [2 ]
654-                     condition_latents , condition_latent_coords  =  self ._pack_latents (
654+                     condition_latents , rope_interpolation_scale  =  self ._pack_latents (
655655                        condition_latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size , device 
656656                    )
657+ 
658+                     rope_interpolation_scale  =  (
659+                         rope_interpolation_scale  *  
660+                         torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = latent_coords .device )[None , :, None ]
661+                     )
662+                     rope_interpolation_scale [:, 0 ] =  (rope_interpolation_scale [:, 0 ] +  1  -  self .vae_temporal_compression_ratio ).clamp (min = 0 )               
663+                     rope_interpolation_scale [:, 0 ] +=  condition .frame_index 
664+ 
657665                    conditioning_mask  =  torch .full (
658666                        condition_latents .shape [:2 ], condition .strength , device = device , dtype = dtype 
659667                    )
660668
661-                     rope_interpolation_scale  =  [
662-                         # TODO!!! This is incorrect: the frame index needs to added AFTER multiplying the interpolation 
663-                         # scale with the grid. 
664-                         (self .vae_temporal_compression_ratio  +  condition .frame_index ) /  frame_rate ,
665-                         self .vae_spatial_compression_ratio ,
666-                         self .vae_spatial_compression_ratio ,
667-                     ]
668-                     rope_interpolation_scale  =  (
669-                         torch .tensor (rope_interpolation_scale , device = device , dtype = dtype )
670-                         .view (- 1 , 1 , 1 , 1 , 1 )
671-                         .repeat (1 , 1 , c_nlf , latent_height , latent_width )
672-                     )
673669                    extra_conditioning_num_latents  +=  condition_latents .size (1 )
674670
675671                    extra_conditioning_latents .append (condition_latents )
676672                    extra_conditioning_rope_interpolation_scales .append (rope_interpolation_scale )
677673                    extra_conditioning_mask .append (conditioning_mask )
678674
679-         latents , latent_coords  =  self ._pack_latents (
675+         latents , rope_interpolation_scale  =  self ._pack_latents (
680676            latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size , device 
681677        )
682-         pixel_coords  =  (
683-             latent_coords 
678+ 
679+         rope_interpolation_scale  =  (
680+             rope_interpolation_scale 
684681            *  torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = latent_coords .device )[None , :, None ]
685682        )
686-         pixel_coords [:, 0 ] =  (pixel_coords [:, 0 ] +  1  -  self .vae_temporal_compression_ratio ).clamp (min = 0 )
687- 
688-         rope_interpolation_scale  =  pixel_coords 
683+         rope_interpolation_scale [:, 0 ] =  (rope_interpolation_scale [:, 0 ] +  1  -  self .vae_temporal_compression_ratio ).clamp (min = 0 )
689684
690685        conditioning_mask  =  condition_latent_frames_mask .gather (
691686            1 , latent_coords [:, 0 ]
692687        )
693688
694-         # YiYi TODO: code path not tested 
695689        if  len (extra_conditioning_latents ) >  0 :
696690            latents  =  torch .cat ([* extra_conditioning_latents , latents ], dim = 1 )
697691            rope_interpolation_scale  =  torch .cat (
0 commit comments