Commit c3bc32e

fix from PR comments

1 parent 56f5b79 commit c3bc32e

File tree: 1 file changed (+32, -31 lines)

src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Lines changed: 32 additions & 31 deletions
@@ -660,8 +660,8 @@ def prepare_image(

     def prepare_mask_latents(
         self,
-        mask,
-        masked_image,
+        image,
+        mask_image,
         batch_size,
         num_channels_latents,
         num_images_per_prompt,
@@ -673,34 +673,40 @@ def prepare_mask_latents(
     ):
         # VAE applies 8x compression on images but we must also account for packing which requires
         # latent height and width to be divisible by 2.
+        image = self.image_processor.preprocess(image, height=height, width=width)
+        mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        masked_image = image * (1 - mask_image)
+        masked_image = masked_image.to(device=device, dtype=dtype)
+
         height = 2 * (int(height) // (self.vae_scale_factor * 2))
         width = 2 * (int(width) // (self.vae_scale_factor * 2))
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
-        mask = torch.nn.functional.interpolate(mask, size=(height, width))
-        mask = mask.to(device=device, dtype=dtype)
+        mask_image = torch.nn.functional.interpolate(mask_image, size=(height, width))
+        mask_image = mask_image.to(device=device, dtype=dtype)

         batch_size = batch_size * num_images_per_prompt

         masked_image = masked_image.to(device=device, dtype=dtype)

-        if masked_image.shape[1] == 16:
+        if masked_image.shape[1] == num_channels_latents:
             masked_image_latents = masked_image
         else:
             masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

         masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        if mask.shape[0] < batch_size:
-            if not batch_size % mask.shape[0] == 0:
+        if mask_image.shape[0] < batch_size:
+            if not batch_size % mask_image.shape[0] == 0:
                 raise ValueError(
                     "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
-                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    f" a total batch size of {batch_size}, but {mask_image.shape[0]} mask images were passed. Make sure the number"
                     " of masks that you pass is divisible by the total requested batch size."
                 )
-            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+            mask_image = mask_image.repeat(batch_size // mask_image.shape[0], 1, 1, 1)
         if masked_image_latents.shape[0] < batch_size:
             if not batch_size % masked_image_latents.shape[0] == 0:
                 raise ValueError(
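
The key move in this hunk: masking now happens inside prepare_mask_latents via image * (1 - mask_image), where the mask is 1 inside the region to repaint and 0 elsewhere, so the product zeroes out the inpaint region before VAE encoding. A minimal toy sketch of just that step (shapes are illustrative, not from the commit):

import torch

image = torch.rand(1, 3, 8, 8)           # [B, C, H, W], values in [0, 1]
mask_image = torch.zeros(1, 1, 8, 8)     # 1 = region to repaint
mask_image[..., 2:6, 2:6] = 1.0

masked_image = image * (1 - mask_image)  # inpaint region zeroed out
assert torch.all(masked_image[..., 2:6, 2:6] == 0)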
@@ -719,15 +725,16 @@ def prepare_mask_latents(
             height,
             width,
         )
-        mask = self._pack_latents(
-            mask.repeat(1, num_channels_latents, 1, 1),
+        mask_image = self._pack_latents(
+            mask_image.repeat(1, num_channels_latents, 1, 1),
             batch_size,
             num_channels_latents,
             height,
             width,
         )
+        masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1)

-        return mask, masked_image_latents
+        return mask_image, masked_image_latents

     @property
     def guidance_scale(self):
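
With the torch.cat folded into prepare_mask_latents, callers now get back latents that already carry the packed mask along the last dimension. A shape-only sketch, assuming Flux's _pack_latents layout of [B, (H/2)*(W/2), C*4] and illustrative sizes for a 512x512 input:

import torch

batch, c, h, w = 1, 16, 64, 64           # latent-space sizes (16 channels)
tokens = (h // 2) * (w // 2)             # 1024 packed positions

masked_image_latents = torch.randn(batch, tokens, c * 4)   # [1, 1024, 64]
mask_image = torch.randn(batch, tokens, c * 4)             # [1, 1024, 64]

# prepare_mask_latents now performs this concat itself before returning:
masked_image_latents = torch.cat((masked_image_latents, mask_image), dim=-1)
print(masked_image_latents.shape)        # torch.Size([1, 1024, 128])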
@@ -759,7 +766,7 @@ def __call__(
         width: Optional[int] = None,
         strength: float = 0.6,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -820,10 +827,10 @@ def __call__(
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
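
With timesteps gone from the signature, callers pass a sigma schedule instead. A hedged usage sketch of the new argument (the checkpoint name follows the diffusers Flux examples; init_image, control_image, and mask_image are assumed to be loaded beforehand):

import numpy as np
import torch
from diffusers import FluxControlInpaintPipeline

pipe = FluxControlInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Custom schedule; if omitted, the pipeline falls back to
# np.linspace(1.0, 1 / num_inference_steps, num_inference_steps).
sigmas = np.linspace(1.0, 1 / 28, 28).tolist()

result = pipe(
    prompt="a red couch in a sunlit living room",
    image=init_image,
    control_image=control_image,
    mask_image=mask_image,
    sigmas=sigmas,
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]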
@@ -927,18 +934,13 @@ def __call__(
         # 3. Preprocess mask and image
         num_channels_latents = self.vae.config.latent_channels
         if masked_image_latents is not None:
+            # pre-computed masked_image_latents and mask_image
             masked_image_latents = masked_image_latents.to(latents.device)
+            mask = mask_image.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
-            mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
-
-            masked_image = image * (1 - mask_image)
-            masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
-
-            height, width = image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
+                image,
                 mask_image,
-                masked_image,
                 batch_size,
                 num_channels_latents,
                 num_images_per_prompt,
@@ -948,13 +950,12 @@ def __call__(
                 device,
                 generator,
             )
-            masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)

         # 4. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
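
The fallback schedule is unchanged; only the `if sigmas is None` guard is new. For the default 28 steps it is a linear ramp from 1.0 down to 1/28:

import numpy as np

num_inference_steps = 28
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
print(sigmas[:3])   # [1.         0.96428571 0.92857143]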
@@ -967,8 +968,7 @@ def __call__(
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
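
retrieve_timesteps, the helper shared across the diffusers pipelines, accepts either a timesteps or a sigmas list (not both) and forwards extra keyword arguments such as mu to the scheduler's set_timesteps. A sketch of the call as it now reads, assuming pipe, sigmas, and mu from the surrounding code:

timesteps, num_inference_steps = retrieve_timesteps(
    pipe.scheduler,
    num_inference_steps=28,
    device="cuda",
    sigmas=sigmas,   # mutually exclusive with `timesteps`
    mu=mu,           # forwarded to scheduler.set_timesteps via **kwargs
)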
@@ -1062,11 +1062,12 @@ def __call__(
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                 # for 64 channel transformer only.
-                init_latents_proper = image_latents
                 init_mask = mask
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(init_latents_proper, torch.tensor([noise_timestep]), noise)
+                    init_latents_proper = self.scheduler.scale_noise(image_latents, torch.tensor([noise_timestep]), noise)
+                else:
+                    init_latents_proper = image_latents
                 init_latents_proper = self._pack_latents(init_latents_proper, batch_size * num_images_per_prompt, num_channels_latents, height_8, width_8)

                 latents = (1 - init_mask) * init_latents_proper + init_mask * latents
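
The restructured branch noises the clean image latents to the next timestep on all but the final step, where they are used as-is; the blend that follows keeps the model's prediction inside the mask and restores the (re-noised) original outside it. A toy sketch of that blend with illustrative packed-latent shapes:

import torch

latents = torch.randn(1, 1024, 64)                    # model prediction (packed)
init_latents_proper = torch.randn(1, 1024, 64)        # re-noised or clean image latents
init_mask = (torch.rand(1, 1024, 64) > 0.5).float()   # 1 = region being repainted

# Inside the mask keep the denoised prediction; outside, restore the original.
latents = (1 - init_mask) * init_latents_proper + init_mask * latents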
