Skip to content

Commit 5f8f9fa

Browse files
committed
Flux img2img remote encode
1 parent 1001425 commit 5f8f9fa

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -225,7 +225,10 @@ def __init__(
225225
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
226226
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
227227
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
228-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
228+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
229+
self.image_processor = VaeImageProcessor(
230+
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
231+
)
229232
self.tokenizer_max_length = (
230233
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
231234
)
@@ -634,7 +637,10 @@ def prepare_latents(
634637
return latents.to(device=device, dtype=dtype), latent_image_ids
635638

636639
image = image.to(device=device, dtype=dtype)
637-
image_latents = self._encode_vae_image(image=image, generator=generator)
640+
if image.shape[1] != self.latent_channels:
641+
image_latents = self._encode_vae_image(image=image, generator=generator)
642+
else:
643+
image_latents = image
638644
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
639645
# expand init_latents for batch_size
640646
additional_image_per_prompt = batch_size // image_latents.shape[0]

0 commit comments

Comments (0)