@@ -657,7 +657,7 @@ def prepare_latents(
657657
658658                    rope_interpolation_scale  =  (
659659                        rope_interpolation_scale  *  
660-                         torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = latent_coords .device )[None , :, None ]
660+                         torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = rope_interpolation_scale .device )[None , :, None ]
661661                    )
662662                    rope_interpolation_scale [:, 0 ] =  (rope_interpolation_scale [:, 0 ] +  1  -  self .vae_temporal_compression_ratio ).clamp (min = 0 )               
663663                    rope_interpolation_scale [:, 0 ] +=  condition .frame_index 
@@ -675,17 +675,16 @@ def prepare_latents(
675675        latents , rope_interpolation_scale  =  self ._pack_latents (
676676            latents , self .transformer_spatial_patch_size , self .transformer_temporal_patch_size , device 
677677        )
678+         conditioning_mask  =  condition_latent_frames_mask .gather (
679+             1 , rope_interpolation_scale [:, 0 ]
680+         )
678681
679682        rope_interpolation_scale  =  (
680683            rope_interpolation_scale 
681-             *  torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = latent_coords .device )[None , :, None ]
684+             *  torch .tensor ([self .vae_temporal_compression_ratio , self .vae_spatial_compression_ratio , self .vae_spatial_compression_ratio ], device = rope_interpolation_scale .device )[None , :, None ]
682685        )
683686        rope_interpolation_scale [:, 0 ] =  (rope_interpolation_scale [:, 0 ] +  1  -  self .vae_temporal_compression_ratio ).clamp (min = 0 )
684687
685-         conditioning_mask  =  condition_latent_frames_mask .gather (
686-             1 , latent_coords [:, 0 ]
687-         )
688- 
689688        if  len (extra_conditioning_latents ) >  0 :
690689            latents  =  torch .cat ([* extra_conditioning_latents , latents ], dim = 1 )
691690            rope_interpolation_scale  =  torch .cat (
0 commit comments