Skip to content

Commit d523e2b

Browse files
committed
Adjust the denoising loop to generate regular images when inverted latents are not provided
1 parent e318a67 commit d523e2b

File tree

1 file changed

+86
-17
lines changed

1 file changed

+86
-17
lines changed

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 86 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,9 @@ def check_inputs(
419419
self,
420420
prompt,
421421
prompt_2,
422+
inverted_latents,
423+
image_latents,
424+
latent_image_ids,
422425
height,
423426
width,
424427
start_timestep,
@@ -467,6 +470,10 @@ def check_inputs(
467470
if max_sequence_length is not None and max_sequence_length > 512:
468471
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
469472

473+
if inverted_latents is not None and (image_latents is None or latent_image_ids is None):
474+
raise ValueError(
475+
"If `inverted_latents` are provided, `image_latents` and `latent_image_ids` also have to be passed. "
476+
)
470477
# check start_timestep and stop_timestep
471478
if start_timestep < 0 or start_timestep > stop_timestep:
472479
raise ValueError(f"`start_timestep` should be in [0, stop_timestep] but is {start_timestep}")
@@ -536,7 +543,7 @@ def disable_vae_tiling(self):
536543
"""
537544
self.vae.disable_tiling()
538545

539-
def prepare_latents(
546+
def prepare_latents_inversion(
540547
self,
541548
batch_size,
542549
num_channels_latents,
@@ -555,6 +562,41 @@ def prepare_latents(
555562

556563
return latents, latent_image_ids
557564

565+
def prepare_latents(
566+
self,
567+
batch_size,
568+
num_channels_latents,
569+
height,
570+
width,
571+
dtype,
572+
device,
573+
generator,
574+
latents=None,
575+
):
576+
# VAE applies 8x compression on images but we must also account for packing which requires
577+
# latent height and width to be divisible by 2.
578+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
579+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
580+
581+
shape = (batch_size, num_channels_latents, height, width)
582+
583+
if latents is not None:
584+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
585+
return latents.to(device=device, dtype=dtype), latent_image_ids
586+
587+
if isinstance(generator, list) and len(generator) != batch_size:
588+
raise ValueError(
589+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
590+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
591+
)
592+
593+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
594+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
595+
596+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
597+
598+
return latents, latent_image_ids
599+
558600
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
559601
def get_timesteps(self, num_inference_steps, strength=1.0):
560602
# get the original timestep using init_timestep
@@ -588,11 +630,11 @@ def interrupt(self):
588630
@replace_example_docstring(EXAMPLE_DOC_STRING)
589631
def __call__(
590632
self,
591-
latents: Optional[torch.FloatTensor] = None,
592-
image_latents: Optional[torch.FloatTensor] = None,
593-
latent_image_ids: Optional[torch.FloatTensor] = None,
594633
prompt: Union[str, List[str]] = None,
595634
prompt_2: Optional[Union[str, List[str]]] = None,
635+
inverted_latents: Optional[torch.FloatTensor] = None,
636+
image_latents: Optional[torch.FloatTensor] = None,
637+
latent_image_ids: Optional[torch.FloatTensor] = None,
596638
height: Optional[int] = None,
597639
width: Optional[int] = None,
598640
eta: float = 1.0,
@@ -604,6 +646,7 @@ def __call__(
604646
guidance_scale: float = 3.5,
605647
num_images_per_prompt: Optional[int] = 1,
606648
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
649+
latents: Optional[torch.FloatTensor] = None,
607650
prompt_embeds: Optional[torch.FloatTensor] = None,
608651
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
609652
output_type: Optional[str] = "pil",
@@ -693,6 +736,9 @@ def __call__(
693736
self.check_inputs(
694737
prompt,
695738
prompt_2,
739+
inverted_latents,
740+
image_latents,
741+
latent_image_ids,
696742
height,
697743
width,
698744
start_timestep,
@@ -706,6 +752,7 @@ def __call__(
706752
self._guidance_scale = guidance_scale
707753
self._joint_attention_kwargs = joint_attention_kwargs
708754
self._interrupt = False
755+
do_rf_inversion = inverted_latents is not None
709756

710757
# 2. Define call parameters
711758
if prompt is not None and isinstance(prompt, str):
@@ -737,6 +784,19 @@ def __call__(
737784

738785
# 4. Prepare latent variables
739786
num_channels_latents = self.transformer.config.in_channels // 4
787+
if do_rf_inversion:
788+
latents = inverted_latents
789+
else:
790+
latents, latent_image_ids = self.prepare_latents(
791+
batch_size * num_images_per_prompt,
792+
num_channels_latents,
793+
height,
794+
width,
795+
prompt_embeds.dtype,
796+
device,
797+
generator,
798+
latents,
799+
)
740800

741801
# 5. Prepare timesteps
742802
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
@@ -769,9 +829,11 @@ def __call__(
769829
else:
770830
guidance = None
771831

832+
if do_rf_inversion:
833+
y_0 = image_latents.clone()
772834
# 6. Denoising loop
773835
with self.progress_bar(total=num_inference_steps) as progress_bar:
774-
y_0 = image_latents.clone()
836+
775837
for i, t in enumerate(timesteps):
776838
t_i = 1 - t / 1000
777839
dt = torch.tensor(1 / (len(timesteps) - 1), device=device)
@@ -782,7 +844,7 @@ def __call__(
782844
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
783845
timestep = t.expand(latents.shape[0]).to(latents.dtype)
784846

785-
v_t = -self.transformer(
847+
noise_pred = self.transformer(
786848
hidden_states=latents,
787849
timestep=timestep / 1000,
788850
guidance=guidance,
@@ -794,18 +856,25 @@ def __call__(
794856
return_dict=False,
795857
)[0]
796858

797-
v_t_cond = (y_0 - latents) / (1 - t_i)
798-
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
799-
if start_timestep <= i < stop_timestep:
800-
# controlled vector field
801-
v_hat_t = v_t + eta * (v_t_cond - v_t)
859+
if do_rf_inversion:
860+
v_t = -noise_pred
802861

803-
else:
804-
v_hat_t = v_t
805-
# SDE Eq: 17
862+
v_t_cond = (y_0 - latents) / (1 - t_i)
863+
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
864+
if start_timestep <= i < stop_timestep:
865+
# controlled vector field
866+
v_hat_t = v_t + eta * (v_t_cond - v_t)
806867

807-
latents_dtype = latents.dtype
808-
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
868+
else:
869+
v_hat_t = v_t
870+
# SDE Eq: 17
871+
872+
latents_dtype = latents.dtype
873+
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
874+
else:
875+
# compute the previous noisy sample x_t -> x_t-1
876+
latents_dtype = latents.dtype
877+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
809878

810879
if latents.dtype != latents_dtype:
811880
if torch.backends.mps.is_available():
@@ -898,7 +967,7 @@ def invert(
898967

899968
# 1. prepare image
900969
image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype)
901-
image_latents, latent_image_ids = self.prepare_latents(
970+
image_latents, latent_image_ids = self.prepare_latents_inversion(
902971
batch_size, num_channels_latents, height, width, dtype, device, image_latents
903972
)
904973

0 commit comments

Comments
 (0)