@@ -97,6 +97,20 @@ def calculate_shift(
9797    return  mu 
9898
9999
100+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of an encoder output object.

    The lookup is attribute-based (``hasattr``), so although ``encoder_output``
    is annotated as ``torch.Tensor`` it is only ever probed for ``latent_dist``
    and ``latents`` attributes here.

    Args:
        encoder_output: Object returned by an encoder; expected to expose
            either ``latent_dist`` or ``latents``.
        generator: Passed positionally to ``latent_dist.sample`` when
            ``sample_mode`` is ``"sample"``; unused otherwise.
        sample_mode: ``"sample"`` draws via ``latent_dist.sample(generator)``;
            ``"argmax"`` returns ``latent_dist.mode()``.

    Returns:
        The sampled latents, the distribution mode, or the ``latents``
        attribute, depending on which branch matches.

    Raises:
        AttributeError: If no branch matches, i.e. ``encoder_output`` has no
            ``latents`` attribute and either lacks ``latent_dist`` or
            ``sample_mode`` is neither ``"sample"`` nor ``"argmax"``.

    NOTE(review): this function sits under a ``# Copied from`` marker; confirm
    whether the repository's copy-consistency tooling accepts documentation
    that differs from the upstream definition.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        # Also reached when `latent_dist` exists but `sample_mode` is unrecognised.
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
112+ 
113+ 
100114# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 
101115def  retrieve_timesteps (
102116    scheduler ,
@@ -512,7 +526,7 @@ def prepare_latents(
512526        shape  =  (batch_size , num_channels_latents , height , width )
513527
514528        if  latents  is  not None :
515-             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height , width , device , dtype )
529+             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height   //   2 , width   //   2 , device , dtype )
516530            return  latents .to (device = device , dtype = dtype ), latent_image_ids 
517531
518532        if  isinstance (generator , list ) and  len (generator ) !=  batch_size :
@@ -772,7 +786,7 @@ def __call__(
772786            controlnet_blocks_repeat  =  False  if  self .controlnet .input_hint_block  is  None  else  True 
773787            if  self .controlnet .input_hint_block  is  None :
774788                # vae encode 
775-                 control_image  =  self .vae .encode (control_image ). latent_dist . sample ( )
789+                 control_image  =  retrieve_latents ( self .vae .encode (control_image ),  generator = generator )
776790                control_image  =  (control_image  -  self .vae .config .shift_factor ) *  self .vae .config .scaling_factor 
777791
778792                # pack 
@@ -810,7 +824,7 @@ def __call__(
810824
811825                if  self .controlnet .nets [0 ].input_hint_block  is  None :
812826                    # vae encode 
813-                     control_image_  =  self .vae .encode (control_image_ ). latent_dist . sample ( )
827+                     control_image_  =  retrieve_latents ( self .vae .encode (control_image_ ),  generator = generator )
814828                    control_image_  =  (control_image_  -  self .vae .config .shift_factor ) *  self .vae .config .scaling_factor 
815829
816830                    # pack 
0 commit comments