@@ -97,6 +97,20 @@ def calculate_shift(
9797    return  mu 
9898
9999
100+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 
def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample",
):
    """Extract latents from an encoder output object.

    Args:
        encoder_output: Object returned by an encoder. It is probed for a
            `latent_dist` attribute first and a `latents` attribute second.
        generator: Passed through to `latent_dist.sample(...)` when sampling.
        sample_mode: `"sample"` draws from `latent_dist`; `"argmax"` takes
            `latent_dist.mode()`. Any other value skips `latent_dist`
            entirely and falls back to the `latents` attribute.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents`
        attribute, depending on what `encoder_output` exposes.

    Raises:
        AttributeError: If none of the lookups above applies.
    """
    # A distribution is only usable together with a recognised sample_mode.
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()

    # Fallback: the encoder output already carries ready-made latents.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
112+ 
113+ 
100114# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps 
101115def  retrieve_timesteps (
102116    scheduler ,
@@ -512,7 +526,7 @@ def prepare_latents(
512526        shape  =  (batch_size , num_channels_latents , height , width )
513527
514528        if  latents  is  not None :
515-             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height , width , device , dtype )
529+             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height   //   2 , width   //   2 , device , dtype )
516530            return  latents .to (device = device , dtype = dtype ), latent_image_ids 
517531
518532        if  isinstance (generator , list ) and  len (generator ) !=  batch_size :
@@ -773,7 +787,7 @@ def __call__(
773787            controlnet_blocks_repeat  =  False  if  self .controlnet .input_hint_block  is  None  else  True 
774788            if  self .controlnet .input_hint_block  is  None :
775789                # vae encode 
776-                 control_image  =  self .vae .encode (control_image ). latent_dist . sample ( )
790+                 control_image  =  retrieve_latents ( self .vae .encode (control_image ),  generator = generator )
777791                control_image  =  (control_image  -  self .vae .config .shift_factor ) *  self .vae .config .scaling_factor 
778792
779793                # pack 
@@ -811,7 +825,7 @@ def __call__(
811825
812826                if  self .controlnet .nets [0 ].input_hint_block  is  None :
813827                    # vae encode 
814-                     control_image_  =  self .vae .encode (control_image_ ). latent_dist . sample ( )
828+                     control_image_  =  retrieve_latents ( self .vae .encode (control_image_ ),  generator = generator )
815829                    control_image_  =  (control_image_  -  self .vae .config .shift_factor ) *  self .vae .config .scaling_factor 
816830
817831                    # pack 
0 commit comments