@@ -97,6 +97,20 @@ def calculate_shift(
9797 return mu
9898
9999
100+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101+ def retrieve_latents (
102+ encoder_output : torch .Tensor , generator : Optional [torch .Generator ] = None , sample_mode : str = "sample"
103+ ):
104+ if hasattr (encoder_output , "latent_dist" ) and sample_mode == "sample" :
105+ return encoder_output .latent_dist .sample (generator )
106+ elif hasattr (encoder_output , "latent_dist" ) and sample_mode == "argmax" :
107+ return encoder_output .latent_dist .mode ()
108+ elif hasattr (encoder_output , "latents" ):
109+ return encoder_output .latents
110+ else :
111+ raise AttributeError ("Could not access latents of provided encoder_output" )
112+
113+
100114# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101115def retrieve_timesteps (
102116 scheduler ,
@@ -512,7 +526,7 @@ def prepare_latents(
512526 shape = (batch_size , num_channels_latents , height , width )
513527
514528 if latents is not None :
515- latent_image_ids = self ._prepare_latent_image_ids (batch_size , height , width , device , dtype )
529+ latent_image_ids = self ._prepare_latent_image_ids (batch_size , height // 2 , width // 2 , device , dtype )
516530 return latents .to (device = device , dtype = dtype ), latent_image_ids
517531
518532 if isinstance (generator , list ) and len (generator ) != batch_size :
@@ -773,7 +787,7 @@ def __call__(
773787 controlnet_blocks_repeat = False if self .controlnet .input_hint_block is None else True
774788 if self .controlnet .input_hint_block is None :
775789 # vae encode
776- control_image = self .vae .encode (control_image ). latent_dist . sample ( )
790+ control_image = retrieve_latents ( self .vae .encode (control_image ), generator = generator )
777791 control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
778792
779793 # pack
@@ -811,7 +825,7 @@ def __call__(
811825
812826 if self .controlnet .nets [0 ].input_hint_block is None :
813827 # vae encode
814- control_image_ = self .vae .encode (control_image_ ). latent_dist . sample ( )
828+ control_image_ = retrieve_latents ( self .vae .encode (control_image_ ), generator = generator )
815829 control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
816830
817831 # pack
0 commit comments