diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 08d0b44d613d..ef25b8c024a8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -710,6 +710,8 @@ def prepare_latents( batch_size = batch_size * num_images_per_prompt + vae_original_dtype = self.vae.dtype + if image.shape[1] == 4: init_latents = image @@ -742,7 +744,7 @@ def prepare_latents( init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: - self.vae.to(dtype) + self.vae.to(vae_original_dtype) init_latents = init_latents.to(dtype) if latents_mean is not None and latents_std is not None: @@ -1459,10 +1461,13 @@ def denoising_value_valid(dnv): if needs_upcasting: self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - elif latents.dtype != self.vae.dtype: + + if latents.dtype != self.vae.dtype: if torch.backends.mps.is_available(): # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 self.vae = self.vae.to(latents.dtype) + else: + latents = latents.to(self.vae.dtype) # unscale/denormalize the latents # denormalize with the mean and std if available and not None