@@ -456,9 +456,8 @@ def prepare_latents(
         device: Optional[torch.device] = None,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.Tensor] = None,
-        sigma: torch.Tensor = 1.0,
+        timestep: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        # TODO: do we need the `conditioning_mask` here? I think `conditioning_mask` should be all ones.
         height = height // self.vae_spatial_compression_ratio
         width = width // self.vae_spatial_compression_ratio
 
@@ -471,14 +470,9 @@ def prepare_latents(
             (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
         )
         shape = (batch_size, num_channels_latents, num_frames, height, width)
-        mask_shape = (batch_size, 1, num_frames, height, width)
 
         if latents is not None:
-            conditioning_mask = latents.new_ones(shape)
-            conditioning_mask = self._pack_latents(
-                conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-            )
-            return latents.to(device=device, dtype=dtype), conditioning_mask
+            return latents.to(device=device, dtype=dtype)
 
         if isinstance(generator, list):
             if len(generator) != batch_size:
@@ -491,30 +485,21 @@ def prepare_latents(
                 retrieve_latents(self.vae.encode(video[i].unsqueeze(0).permute(0, 2, 1, 3, 4)), generator[i])
                 for i in range(batch_size)
             ]
-        else: # `premute()` because we want `batch_size, num_channels, num_frames, height, width`
+        else:  # `permute()` because we want `batch_size, num_channels, num_frames, height, width`
             init_latents = [
                 retrieve_latents(self.vae.encode(vid.unsqueeze(0).permute(0, 2, 1, 3, 4)), generator) for vid in video
             ]
 
         init_latents = torch.cat(init_latents, dim=0).to(dtype)
         init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
-        # `ones()` as we want to condition on all?
-        conditioning_mask = torch.ones(mask_shape, device=device, dtype=dtype)
-
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        # TODO: consider adding the noise w.r.t the flow equation? CogVideoX vid2vid
-        # adds this noise with `add_noise()`.
-        latents = (1 - sigma) * init_latents + sigma * noise
-        # latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
-
-        conditioning_mask = self._pack_latents(
-            conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
-        ).squeeze(-1)
+
+        latents = self.scheduler.scale_noise(sample=init_latents, timestep=timestep, noise=noise)
         latents = self._pack_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
 
-        return latents, conditioning_mask
+        return latents
 
     @property
     def guidance_scale(self):
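
The hand-rolled `(1 - sigma) * init_latents + sigma * noise` blend above is replaced by `scheduler.scale_noise`, which resolves the sigma for the given `timestep` from its own schedule. A minimal sketch of the interpolation it performs, assuming the flow-matching formulation of `FlowMatchEulerDiscreteScheduler` (the timestep-to-sigma lookup is elided):

```python
import torch


def scale_noise_sketch(sample: torch.Tensor, sigma: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    # Flow-matching forward process: linear interpolation between the clean
    # sample and Gaussian noise. sigma=0 leaves the sample untouched; sigma=1
    # yields pure noise. The real scale_noise() first resolves `timestep` to
    # its sigma using the scheduler's precomputed schedule.
    return (1.0 - sigma) * sample + sigma * noise
```
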
@@ -536,6 +521,16 @@ def attention_kwargs(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, timesteps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep, 0)
+        timesteps = timesteps[t_start * self.scheduler.order :]
+
+        return timesteps, num_inference_steps - t_start
+
539534    @torch .no_grad () 
540535    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
541536    def  __call__ (
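
A quick worked example of the truncation `get_timesteps` performs (hypothetical numbers; `scheduler.order` is 1 for Euler-style schedulers):

```python
num_inference_steps, strength, order = 50, 0.8, 1

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep, 0)  # 10

# timesteps[t_start * order:] drops the 10 noisiest entries, so the denoising
# loop runs 40 steps and the input video only needs to be noised to the level
# of the 11th timestep rather than to pure noise.
```
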
@@ -549,6 +544,7 @@ def __call__(
         frame_rate: int = 25,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
+        strength: float = 0.8,
         guidance_scale: float = 3,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -585,6 +581,9 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
+            strength (`float`, defaults to `0.8`):
+                Indicates how much to transform the reference `video`. Must be between 0 and 1. The input is noised
+                to the matching point of the schedule, so higher values preserve less of the source video.
             guidance_scale (`float`, defaults to `3`):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -643,7 +640,7 @@ def __call__(
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
         # 1. Check inputs. Raise error if not correct
-        # TODO: check for the `video`
+        # TODO: also validate the `video` and `strength` inputs
         self.check_inputs(
             prompt=prompt,
             height=height,
@@ -726,15 +723,14 @@ def __call__(
             sigmas=sigmas,
             mu=mu,
         )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        latent_sigma = torch.tensor(
-            sigmas[:1].repeat(batch_size * num_videos_per_prompt), dtype=prompt_embeds.dtype, device=device
-        )
         self._num_timesteps = len(timesteps)
 
         # 6. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
-        latents, conditioning_mask = self.prepare_latents(
+        latents = self.prepare_latents(
             video,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -745,10 +741,8 @@ def __call__(
             device,
             generator,
             latents,
-            sigma=latent_sigma,
+            timestep=latent_timestep,
         )
-        if self.do_classifier_free_guidance:
-            conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
 
         # 7. Prepare micro-conditions
         latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
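
The first entry of the truncated schedule becomes the noising target for every latent in the batch. A small illustration with hypothetical timestep values:

```python
import torch

# Hypothetical truncated schedule returned by get_timesteps (descending).
timesteps = torch.tensor([800.0, 784.0, 768.0])

batch_size, num_videos_per_prompt = 1, 2
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
print(latent_timestep)  # tensor([800., 800.]) -- one starting level per video
```
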
@@ -769,8 +763,6 @@ def __call__(
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
-                # timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
-                timestep = timestep.unsqueeze(-1)
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -788,7 +780,6 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-                    timestep, _ = timestep.chunk(2)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 noise_pred = self._unpack_latents(
@@ -807,12 +798,7 @@ def __call__(
                     self.transformer_spatial_patch_size,
                     self.transformer_temporal_patch_size,
                 )
-
-                noise_pred = noise_pred[:, :, 1:]
-                noise_latents = latents[:, :, 1:]
-                pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
-
-                latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                 latents = self._pack_latents(
                     latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
                 )
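
With the conditioning mask gone, `scheduler.step` denoises all latent frames uniformly instead of carrying the first frame through unchanged, and `strength` is the only knob controlling how much of the source survives. A hypothetical usage sketch (`pipe` and `video` are assumed here, not defined by this diff):

```python
import torch

# `pipe` is assumed to be an instance of this video-to-video pipeline and
# `video` a list of PIL frames loaded elsewhere.
output = pipe(
    video=video,
    prompt="a storm rolling over the mountains",
    num_inference_steps=50,
    strength=0.8,  # noise the input to 80% of the schedule, then denoise
    guidance_scale=3.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
)
frames = output.frames[0]  # assuming a `.frames` field as in other video pipelines
```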