@@ -419,6 +419,9 @@ def check_inputs(
         self,
         prompt,
         prompt_2,
+        inverted_latents,
+        image_latents,
+        latent_image_ids,
         height,
         width,
         start_timestep,
@@ -467,6 +470,10 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
+        if inverted_latents is not None and (image_latents is None or latent_image_ids is None):
+            raise ValueError(
+                "If `inverted_latents` is provided, `image_latents` and `latent_image_ids` must also be passed."
+            )
         # check start_timestep and stop_timestep
         if start_timestep < 0 or start_timestep > stop_timestep:
             raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}")
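The three tensors added to `check_inputs` appear to travel as a trio produced by `invert()`, which is why the guard couples them. A minimal standalone sketch of the rule (the helper name `_validate_rf_inputs` is hypothetical, not part of the pipeline):

```python
# Sketch of the coupling rule enforced above; illustrative only.
def _validate_rf_inputs(inverted_latents, image_latents, latent_image_ids):
    if inverted_latents is not None and (image_latents is None or latent_image_ids is None):
        raise ValueError(
            "If `inverted_latents` is provided, `image_latents` and "
            "`latent_image_ids` must also be passed."
        )

_validate_rf_inputs(None, None, None)  # plain text-to-image call: passes
try:
    _validate_rf_inputs(object(), None, None)  # incomplete trio: raises
except ValueError as e:
    print(e)
```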
@@ -536,7 +543,7 @@ def disable_vae_tiling(self):
536543 """
537544 self .vae .disable_tiling ()
538545
539- def prepare_latents (
546+ def prepare_latents_inversion (
540547 self ,
541548 batch_size ,
542549 num_channels_latents ,
@@ -555,6 +562,41 @@ def prepare_latents(
 
         return latents, latent_image_ids
 
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+        shape = (batch_size, num_channels_latents, height, width)
+
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
     def get_timesteps(self, num_inference_steps, strength=1.0):
         # get the original timestep using init_timestep
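For intuition, here is how the latent-shape arithmetic in the new `prepare_latents` plays out; the values `vae_scale_factor = 8` and `transformer.config.in_channels = 64` are assumed Flux defaults, not stated in this diff:

```python
# Worked example of the rounding and packing math above (assumed Flux defaults).
vae_scale_factor = 8
height, width = 1024, 1024

# Snap latent dims down to multiples of 2 so 2x2 packing is possible.
lat_h = 2 * (int(height) // (vae_scale_factor * 2))  # -> 128
lat_w = 2 * (int(width) // (vae_scale_factor * 2))   # -> 128

num_channels_latents = 64 // 4  # transformer.config.in_channels // 4 -> 16
# _pack_latents folds every 2x2 latent patch into the channel axis,
# giving a packed tensor of shape (batch, seq_len, channels):
seq_len = (lat_h // 2) * (lat_w // 2)  # -> 4096 tokens
channels = num_channels_latents * 4    # -> 64 per token
print(seq_len, channels)               # 4096 64
```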
@@ -588,11 +630,11 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        latents: Optional[torch.FloatTensor] = None,
-        image_latents: Optional[torch.FloatTensor] = None,
-        latent_image_ids: Optional[torch.FloatTensor] = None,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        inverted_latents: Optional[torch.FloatTensor] = None,
+        image_latents: Optional[torch.FloatTensor] = None,
+        latent_image_ids: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         eta: float = 1.0,
@@ -604,6 +646,7 @@ def __call__(
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -693,6 +736,9 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            inverted_latents,
+            image_latents,
+            latent_image_ids,
             height,
             width,
             start_timestep,
@@ -706,6 +752,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
+        do_rf_inversion = inverted_latents is not None
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -737,6 +784,19 @@ def __call__(
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
+        if do_rf_inversion:
+            latents = inverted_latents
+        else:
+            latents, latent_image_ids = self.prepare_latents(
+                batch_size * num_images_per_prompt,
+                num_channels_latents,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                latents,
+            )
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
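The dispatch above sets the sampler's starting point: fresh packed Gaussian noise on the text-to-image path, the caller's inverted latents on the editing path (which is also why `latent_image_ids` must arrive from the caller: no `prepare_latents` call runs in that branch). A toy sketch of the two start states, reusing the packed shape from the earlier example:

```python
import torch

packed_shape = (1, 4096, 64)  # (batch, tokens, channels) from the 1024x1024 example

# Text-to-image path: start from packed Gaussian noise.
latents = torch.randn(packed_shape)

# RF-inversion path: start from the latents produced by invert()
# (a stand-in random tensor here, purely for illustration).
inverted_latents = torch.randn(packed_shape)
latents = inverted_latents
```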
@@ -769,9 +829,11 @@ def __call__(
         else:
             guidance = None
 
+        if do_rf_inversion:
+            y_0 = image_latents.clone()
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            y_0 = image_latents.clone()
+
             for i, t in enumerate(timesteps):
                 t_i = 1 - t / 1000
                 dt = torch.tensor(1 / (len(timesteps) - 1), device=device)
@@ -782,7 +844,7 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                v_t = -self.transformer(
+                noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
                     guidance=guidance,
@@ -794,18 +856,25 @@ def __call__(
                     return_dict=False,
                 )[0]
 
-                v_t_cond = (y_0 - latents) / (1 - t_i)
-                eta_t = eta if start_timestep <= i < stop_timestep else 0.0
-                if start_timestep <= i < stop_timestep:
-                    # controlled vector field
-                    v_hat_t = v_t + eta * (v_t_cond - v_t)
+                if do_rf_inversion:
+                    v_t = -noise_pred
 
-                else:
-                    v_hat_t = v_t
-                # SDE Eq: 17
+                    v_t_cond = (y_0 - latents) / (1 - t_i)
+                    eta_t = eta if start_timestep <= i < stop_timestep else 0.0
+                    if start_timestep <= i < stop_timestep:
+                        # controlled vector field
+                        v_hat_t = v_t + eta * (v_t_cond - v_t)
 
-                latents_dtype = latents.dtype
-                latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
+                    else:
+                        v_hat_t = v_t
+                    # SDE Eq: 17
+
+                    latents_dtype = latents.dtype
+                    latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
+                else:
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents_dtype = latents.dtype
+                    latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
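The inversion branch implements the controlled vector field referenced by the `# SDE Eq: 17` comment: v_hat_t = v_t + eta * (v_t_cond - v_t), where v_t_cond = (y_0 - x_t) / (1 - t) is the drift that would carry the current sample straight back to the reference latents y_0, and eta interpolates between plain sampling (eta = 0) and full reconstruction guidance (eta = 1), applied only on steps i in [start_timestep, stop_timestep). Note that `eta_t` is computed but the blend still uses `eta`; inside the guarded branch the two are equal, so behavior is unaffected. A self-contained toy restatement of one Euler step (shapes and values illustrative only):

```python
import torch

torch.manual_seed(0)
y_0 = torch.randn(1, 4096, 64)  # reference (clean) packed latents
x_t = torch.randn(1, 4096, 64)  # current sample
v_t = torch.randn(1, 4096, 64)  # unconditional velocity, i.e. -noise_pred
t_i, eta = 0.3, 0.9

v_cond = (y_0 - x_t) / (1 - t_i)    # drift pointing at the reference
v_hat = v_t + eta * (v_cond - v_t)  # controlled vector field
x_next = x_t + v_hat * 0.05         # Euler step; step size plays the role of sigmas[i] - sigmas[i + 1]
```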
@@ -898,7 +967,7 @@ def invert(
 
         # 1. prepare image
         image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype)
-        image_latents, latent_image_ids = self.prepare_latents(
+        image_latents, latent_image_ids = self.prepare_latents_inversion(
             batch_size, num_channels_latents, height, width, dtype, device, image_latents
         )
 
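Taken together, the rename plus the reordered `__call__` signature suggest the following end-to-end flow. The checkpoint id, the community-pipeline name, and `invert()`'s exact signature and return values are assumptions, not guaranteed by this diff:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Assumed checkpoint and community pipeline name.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_rf_inversion",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("input.jpg")  # image to be edited

# Inversion pass; per this diff, invert() prepares image_latents and
# latent_image_ids via prepare_latents_inversion(). The three return
# values are an assumption about the pipeline's interface.
inverted_latents, image_latents, latent_image_ids = pipe.invert(
    image=image, num_inversion_steps=28, gamma=0.5
)

# Editing pass: the trio must be passed together, or check_inputs raises.
edited = pipe(
    prompt="a photo of a cat wearing a top hat",
    inverted_latents=inverted_latents,
    image_latents=image_latents,
    latent_image_ids=latent_image_ids,
    num_inference_steps=28,
    start_timestep=0,
    stop_timestep=0.38,
    eta=0.9,
).images[0]
```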