Skip to content

Commit cb43412

Browse files
committed
update prepare_latents from flux.img2img pipeline
1 parent e87c9eb commit cb43412

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,11 @@ def __init__(
224224
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
225225
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
226226
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
227-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
228-
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
227+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
228+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2,vae_latent_channels=self.latent_channels)
229229
self.mask_processor = VaeImageProcessor(
230230
vae_scale_factor=self.vae_scale_factor * 2,
231-
vae_latent_channels=latent_channels,
231+
vae_latent_channels=self.latent_channels,
232232
do_normalize=False,
233233
do_binarize=True,
234234
do_convert_grayscale=True,
@@ -686,7 +686,10 @@ def prepare_latents(
686686
return latents.to(device=device, dtype=dtype), latent_image_ids
687687

688688
image = image.to(device=device, dtype=dtype)
689-
image_latents = self._encode_vae_image(image=image, generator=generator)
689+
if image.shape[1] != self.latent_channels:
690+
image_latents = self._encode_vae_image(image=image, generator=generator)
691+
else:
692+
image_latents = image
690693
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
691694
# expand init_latents for batch_size
692695
additional_image_per_prompt = batch_size // image_latents.shape[0]

0 commit comments

Comments
 (0)