
Commit 5d6b78c

[refactor] refactor after review

Parent: 25fa97c

File tree: 1 file changed, +17 −22

src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 17 additions & 22 deletions
@@ -225,9 +225,10 @@ def __init__(
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=self.vae.config.latent_channels,
+            vae_latent_channels=latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
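
Note: the new latent_channels fallback keeps __init__ usable when the pipeline is constructed without a VAE (for example vae=None during partial loading), defaulting to Flux's 16 latent channels instead of dereferencing a missing component. A minimal sketch of the pattern; DummyPipeline is purely illustrative:

# Illustrative sketch of the guarded-attribute fallback used above; DummyPipeline is hypothetical.
class DummyPipeline:
    def __init__(self, vae=None):
        self.vae = vae
        # Read latent_channels from the VAE config when present, otherwise fall back to Flux's 16.
        self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16

pipe = DummyPipeline(vae=None)
print(pipe.latent_channels)  # 16
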
@@ -656,7 +657,7 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()

-    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self,
         image,
@@ -670,20 +671,24 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
         shape = (batch_size, num_channels_latents, height, width)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype), latent_image_ids

-        # if latents is not None:
         image = image.to(device=device, dtype=dtype)
         image_latents = self._encode_vae_image(image=image, generator=generator)
-
-        latent_image_ids = self._prepare_latent_image_ids(
-            batch_size, height // 2, width // 2, device, dtype
-        )
         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand init_latents for batch_size
             additional_image_per_prompt = batch_size // image_latents.shape[0]
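
Note: with Flux's usual vae_scale_factor of 8, the dimension arithmetic above maps a 1024x1024 request to a 128x128 latent grid whose sides are guaranteed even, so the 2x2 packing never truncates. A quick sanity check of that computation; the 1024 input size is just an illustrative value:

vae_scale_factor = 8  # typical value for the Flux VAE
height, width = 1024, 1024

# Mirror of the computation in prepare_latents: divide by the VAE compression
# times the 2x2 patch size, then multiply back by 2 so the result stays even.
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
print(latent_height, latent_width)  # 128 128
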
@@ -695,19 +700,10 @@ def prepare_latents(
         else:
             image_latents = torch.cat([image_latents], dim=0)

-        if latents is None:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
-        else:
-            noise = latents.to(device)
-            latents = noise
-
-        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
-        image_latents = self._pack_latents(
-            image_latents, batch_size, num_channels_latents, height, width
-        )
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        return latents, noise, image_latents, latent_image_ids
+        return latents, latent_image_ids

     @property
     def guidance_scale(self):
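
Note: after the refactor, noise injection follows the plain img2img recipe: draw fresh noise, blend it into the encoded image latents at the strength-derived timestep via scheduler.scale_noise, pack, and return only the packed latents plus their position ids. A rough sketch of the blend, assuming the linear interpolation used by flow-matching schedulers such as FlowMatchEulerDiscreteScheduler (this is a simplification, not the scheduler's actual API):

import torch

def scale_noise_sketch(image_latents: torch.Tensor, sigma: float, noise: torch.Tensor) -> torch.Tensor:
    # Flow-matching noising: linear interpolation between clean latents and noise.
    # sigma is derived from the chosen timestep (sigma close to 1 means mostly noise).
    return (1.0 - sigma) * image_latents + sigma * noise

image_latents = torch.zeros(1, 16, 128, 128)
noise = torch.randn(1, 16, 128, 128)
noisy = scale_noise_sketch(image_latents, sigma=0.6, noise=noise)
print(noisy.shape)  # torch.Size([1, 16, 128, 128])
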
@@ -866,7 +862,6 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

-        original_image = image
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)

@@ -935,7 +930,7 @@ def __call__(

         # 5. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+        latents, latent_image_ids = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
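
Note: the call site now unpacks only the packed latents and their position ids; the raw noise and unpacked image latents are no longer threaded through __call__. For orientation, a hypothetical shape check of the packed layout for a 1024x1024 request, assuming Flux's 16 latent channels and a vae_scale_factor of 8:

# 1024x1024 image -> 128x128 latent -> 64x64 grid of 2x2 patches -> 4096 tokens of 64 channels.
batch_size, num_channels_latents, height, width = 1, 16, 128, 128
packed_seq_len = (height // 2) * (width // 2)  # 4096 tokens
packed_dim = num_channels_latents * 4          # 64 channels per packed token
print(packed_seq_len, packed_dim)              # 4096 64
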
