Commit 14c452a

[feat]add strength in flux_fill pipeline
1 parent aeac0a0 commit 14c452a

File tree
1 file changed: +115 −35


src/diffusers/pipelines/flux/pipeline_flux_fill.py

Lines changed: 115 additions & 35 deletions
@@ -225,10 +225,9 @@ def __init__(
         # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
         # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
-        latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor * 2,
-            vae_latent_channels=latent_channels,
+            vae_latent_channels=self.vae.config.latent_channels,
             do_normalize=False,
             do_binarize=True,
             do_convert_grayscale=True,
@@ -493,10 +492,40 @@ def encode_prompt(

         return prompt_embeds, pooled_prompt_embeds, text_ids

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        image_latents = (
+            image_latents - self.vae.config.shift_factor
+        ) * self.vae.config.scaling_factor
+
+        return image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device):
+        # get the original timestep using init_timestep
+        init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+        t_start = int(max(num_inference_steps - init_timestep, 0))
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+        return timesteps, num_inference_steps - t_start
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
+        strength,
         height,
         width,
         prompt_embeds=None,
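
The copied-in get_timesteps helper is what ties `strength` to the schedule: it drops the first part of the timestep list so denoising starts from a partially noised image. A minimal, illustrative sketch of that arithmetic (plain numbers only, assuming a first-order scheduler, i.e. scheduler.order == 1):

# Illustrative sketch only: mirrors the arithmetic of get_timesteps without the scheduler object.
def sketch_get_timesteps(num_inference_steps: int, strength: float):
    init_timestep = min(num_inference_steps * strength, num_inference_steps)
    t_start = int(max(num_inference_steps - init_timestep, 0))
    # The real method also slices self.scheduler.timesteps[t_start:] at this point.
    return t_start, num_inference_steps - t_start

assert sketch_get_timesteps(50, 1.0) == (0, 50)   # strength 1.0 keeps all 50 steps
assert sketch_get_timesteps(50, 0.6) == (20, 30)  # strength 0.6 skips the first 20 steps, runs 30
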
@@ -507,6 +536,9 @@ def check_inputs(
         mask_image=None,
         masked_image_latents=None,
     ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             logger.warning(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -627,6 +659,8 @@ def disable_vae_tiling(self):
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
     def prepare_latents(
         self,
+        image,
+        timestep,
         batch_size,
         num_channels_latents,
         height,
@@ -643,22 +677,37 @@ def prepare_latents(

         shape = (batch_size, num_channels_latents, height, width)

-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+        # if latents is not None:
+        image = image.to(device=device, dtype=dtype)
+        image_latents = self._encode_vae_image(image=image, generator=generator)

-        if isinstance(generator, list) and len(generator) != batch_size:
+        latent_image_ids = self._prepare_latent_image_ids(
+            batch_size, height // 2, width // 2, device, dtype
+        )
+        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+            # expand init_latents for batch_size
+            additional_image_per_prompt = batch_size // image_latents.shape[0]
+            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
             raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
             )
+        else:
+            image_latents = torch.cat([image_latents], dim=0)

-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+        if latents is None:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        else:
+            noise = latents.to(device)
+            latents = noise

-        return latents, latent_image_ids
+        noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+        image_latents = self._pack_latents(
+            image_latents, batch_size, num_channels_latents, height, width
+        )
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+        return latents, noise, image_latents, latent_image_ids

     @property
     def guidance_scale(self):
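
Conceptually, the reworked prepare_latents follows the usual img2img recipe: encode the reference image with the VAE, then noise those latents up to the strength-adjusted starting timestep. A rough standalone sketch of that blend (toy shapes; the sigma value is a stand-in for the sigma at `latent_timestep`, and the linear rule assumes the flow-match scheduler used by Flux, whose scale_noise computes exactly this):

import torch

sigma = 0.6                                   # stand-in for the sigma at the start timestep
image_latents = torch.randn(1, 16, 64, 64)    # stand-in for the VAE-encoded init image
noise = torch.randn_like(image_latents)

# flow-match noising: blend noise and image latents linearly
latents = sigma * noise + (1.0 - sigma) * image_latents
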
@@ -687,6 +736,7 @@ def __call__(
         masked_image_latents: Optional[torch.FloatTensor] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        strength: float = 1.0,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
         guidance_scale: float = 30.0,
@@ -731,6 +781,12 @@ def __call__(
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+                essentially ignores `image`.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
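
With strength exposed on __call__, the fill can be biased toward the original content instead of fully regenerating the masked region. A hypothetical usage sketch (the image URLs and parameter values below are illustrative assumptions, not part of this diff):

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://example.com/input.png")      # placeholder image
mask_image = load_image("https://example.com/mask.png")  # placeholder mask

# strength < 1.0 starts denoising from a partially noised version of `image`,
# so the filled area stays closer to the original pixels.
result = pipe(
    prompt="a small wooden cabin in a clearing",
    image=image,
    mask_image=mask_image,
    height=1024,
    width=1024,
    strength=0.85,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
result.save("filled.png")
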
@@ -794,6 +850,7 @@ def __call__(
         self.check_inputs(
             prompt,
             prompt_2,
+            strength,
             height,
             width,
             prompt_embeds=prompt_embeds,
@@ -809,6 +866,10 @@ def __call__(
         self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False

+        original_image = image
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
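
The reference image is now preprocessed once up front and kept in float32 for VAE encoding. A hedged sketch of what that preprocessing yields (assuming the processor's default normalization; the solid-gray input is only a placeholder):

import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=16)  # 8 * 2, matching this pipeline's image_processor
pil_image = Image.new("RGB", (1024, 1024), color=(128, 128, 128))  # placeholder input

init_image = processor.preprocess(pil_image, height=1024, width=1024).to(dtype=torch.float32)
# With do_normalize=True (the default), pixel values land in roughly [-1, 1]
# and the tensor shape is (batch, channels, height, width).
print(init_image.shape)  # torch.Size([1, 3, 1024, 1024])
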
@@ -821,7 +882,9 @@ def __call__(

         # 3. Prepare prompt embeddings
         lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+            self.joint_attention_kwargs.get("scale", None)
+            if self.joint_attention_kwargs is not None
+            else None
         )
         (
             prompt_embeds,
@@ -838,9 +901,43 @@ def __call__(
             lora_scale=lora_scale,
         )

+        # 6. Prepare timesteps
+        sigmas = (
+            np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            if sigmas is None
+            else sigmas
+        )
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (
+            int(width) // self.vae_scale_factor // 2
+        )
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.base_image_seq_len,
+            self.scheduler.config.max_image_seq_len,
+            self.scheduler.config.base_shift,
+            self.scheduler.config.max_shift,
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
         # 4. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
-        latents, latent_image_ids = self.prepare_latents(
+        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+            init_image,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
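
As a sanity check on the sequence length fed to calculate_shift: for a 1024x1024 request, and assuming the usual Flux VAE scale factor of 8 (the real value comes from the loaded VAE config), the packed latent grid is 64x64, so image_seq_len works out to 4096:

# Illustrative arithmetic only; vae_scale_factor = 8 is an assumption, not read from a model.
height, width = 1024, 1024
vae_scale_factor = 8
image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
assert image_seq_len == 64 * 64 == 4096
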
@@ -855,13 +952,13 @@ def __call__(
         if masked_image_latents is not None:
             masked_image_latents = masked_image_latents.to(latents.device)
         else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
+            # image = self.image_processor.preprocess(image, height=height, width=width)
             mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)

-            masked_image = image * (1 - mask_image)
+            masked_image = init_image * (1 - mask_image)
             masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)

-            height, width = image.shape[-2:]
+            height, width = init_image.shape[-2:]
             mask, masked_image_latents = self.prepare_mask_latents(
                 mask_image,
                 masked_image,
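
The masked-image composition is unchanged apart from reading the preprocessed init_image: pixels where the mask is 1 are zeroed out before VAE encoding, and the rest pass through. A tiny toy illustration (values are made up):

import torch

init_image = torch.tensor([[[[0.5, -0.25], [1.0, 0.0]]]])  # toy 1x1x2x2 image tensor
mask_image = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])    # 1 marks the region to fill

masked_image = init_image * (1 - mask_image)
# masked positions become 0, unmasked pixels are preserved
assert torch.equal(masked_image, torch.tensor([[[[0.0, -0.25], [1.0, 0.0]]]]))
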
@@ -876,23 +973,6 @@ def __call__(
             )
             masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)

-        # 6. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.16),
-        )
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler,
-            num_inference_steps,
-            device,
-            sigmas=sigmas,
-            mu=mu,
-        )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
