Skip to content

Commit 39f3437

Browse files
committed
Flux inpaint
1 parent 5f8f9fa commit 39f3437

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,13 @@ def __init__(
222 222
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
223 223
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
224 224
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
225-
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
226-
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
225+
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
226+
self.image_processor = VaeImageProcessor(
227+
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
228+
)
227 229
self.mask_processor = VaeImageProcessor(
228 230
vae_scale_factor=self.vae_scale_factor * 2,
229-
vae_latent_channels=latent_channels,
231+
vae_latent_channels=self.latent_channels,
230 232
do_normalize=False,
231 233
do_binarize=True,
232 234
do_convert_grayscale=True,
@@ -653,7 +655,10 @@ def prepare_latents(
653 655
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
654 656

655 657
image = image.to(device=device, dtype=dtype)
656-
image_latents = self._encode_vae_image(image=image, generator=generator)
658+
if image.shape[1] != self.latent_channels:
659+
image_latents = self._encode_vae_image(image=image, generator=generator)
660+
else:
661+
image_latents = image
657 662

658 663
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
659 664
# expand init_latents for batch_size
@@ -710,7 +715,9 @@ def prepare_mask_latents(
710 715
else:
711 716
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
712 717

713-
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
718+
masked_image_latents = (
719+
masked_image_latents - self.vae.config.shift_factor
720+
) * self.vae.config.scaling_factor
714 721

715 722
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
716 723
if mask.shape[0] < batch_size:

src/diffusers/utils/remote_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ def prepare_encode(
367 367
if shift_factor is not None:
368 368
parameters["shift_factor"] = shift_factor
369 369
if isinstance(image, torch.Tensor):
370-
data = safetensors.torch._tobytes(image, "tensor")
370+
data = safetensors.torch._tobytes(image.contiguous(), "tensor")
371 371
parameters["shape"] = list(image.shape)
372 372
parameters["dtype"] = str(image.dtype).split(".")[-1]
373 373
else:

0 commit comments

Comments (0)