@@ -588,6 +588,9 @@ def interrupt(self):
588588 @replace_example_docstring (EXAMPLE_DOC_STRING )
589589 def __call__ (
590590 self ,
591+ latents : Optional [torch .FloatTensor ] = None ,
592+ image_latents : Optional [torch .FloatTensor ] = None ,
593+ latent_image_ids : Optional [torch .FloatTensor ] = None ,
591594 prompt : Union [str , List [str ]] = None ,
592595 prompt_2 : Optional [Union [str , List [str ]]] = None ,
593596 height : Optional [int ] = None ,
@@ -601,7 +604,6 @@ def __call__(
601604 guidance_scale : float = 3.5 ,
602605 num_images_per_prompt : Optional [int ] = 1 ,
603606 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
604- latents : Optional [torch .FloatTensor ] = None ,
605607 prompt_embeds : Optional [torch .FloatTensor ] = None ,
606608 pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
607609 output_type : Optional [str ] = "pil" ,
@@ -735,9 +737,6 @@ def __call__(
735737
736738 # 4. Prepare latent variables
737739 num_channels_latents = self .transformer .config .in_channels // 4
738- latents = self .inverted_latents
739- latent_image_ids = self .latent_image_ids
740- image_latents = self .image_latents
741740
742741 # 5. Prepare timesteps
743742 sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
@@ -859,7 +858,6 @@ def invert(
859858 width : Optional [int ] = None ,
860859 timesteps : List [int ] = None ,
861860 dtype : Optional [torch .dtype ] = None ,
862- generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
863861 joint_attention_kwargs : Optional [Dict [str , Any ]] = None ,
864862 ):
865863 r"""
@@ -903,7 +901,6 @@ def invert(
903901 image_latents , latent_image_ids = self .prepare_latents (
904902 batch_size , num_channels_latents , height , width , dtype , device , image_latents
905903 )
906- self .image_latents = image_latents .clone ()
907904
908905 # 2. prepare timesteps
909906 sigmas = np .linspace (1.0 , 1 / num_inversion_steps , num_inversion_steps )
@@ -974,7 +971,5 @@ def invert(
974971 Y_t = Y_t + u_hat_t_i * (sigmas [i ] - sigmas [i + 1 ])
975972 progress_bar .update ()
976973
977- self .inverted_latents = Y_t
978- self .latent_image_ids = latent_image_ids
979-
980- return self .image_latents , Y_t , latent_image_ids
974+ # return the inverted latents (start point for the denoising loop), encoded image & latent image ids
975+ return Y_t , image_latents , latent_image_ids
0 commit comments