Skip to content

Commit e318a67

Browse files
committed
return inversion outputs without self-assigning
1 parent 83e6db3 commit e318a67

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,9 @@ def interrupt(self):
588588
@replace_example_docstring(EXAMPLE_DOC_STRING)
589589
def __call__(
590590
self,
591+
latents: Optional[torch.FloatTensor] = None,
592+
image_latents: Optional[torch.FloatTensor] = None,
593+
latent_image_ids: Optional[torch.FloatTensor] = None,
591594
prompt: Union[str, List[str]] = None,
592595
prompt_2: Optional[Union[str, List[str]]] = None,
593596
height: Optional[int] = None,
@@ -601,7 +604,6 @@ def __call__(
601604
guidance_scale: float = 3.5,
602605
num_images_per_prompt: Optional[int] = 1,
603606
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
604-
latents: Optional[torch.FloatTensor] = None,
605607
prompt_embeds: Optional[torch.FloatTensor] = None,
606608
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
607609
output_type: Optional[str] = "pil",
@@ -735,9 +737,6 @@ def __call__(
735737

736738
# 4. Prepare latent variables
737739
num_channels_latents = self.transformer.config.in_channels // 4
738-
latents = self.inverted_latents
739-
latent_image_ids = self.latent_image_ids
740-
image_latents = self.image_latents
741740

742741
# 5. Prepare timesteps
743742
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
@@ -859,7 +858,6 @@ def invert(
859858
width: Optional[int] = None,
860859
timesteps: List[int] = None,
861860
dtype: Optional[torch.dtype] = None,
862-
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
863861
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
864862
):
865863
r"""
@@ -903,7 +901,6 @@ def invert(
903901
image_latents, latent_image_ids = self.prepare_latents(
904902
batch_size, num_channels_latents, height, width, dtype, device, image_latents
905903
)
906-
self.image_latents = image_latents.clone()
907904

908905
# 2. prepare timesteps
909906
sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)
@@ -974,7 +971,5 @@ def invert(
974971
Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])
975972
progress_bar.update()
976973

977-
self.inverted_latents = Y_t
978-
self.latent_image_ids = latent_image_ids
979-
980-
return self.image_latents, Y_t, latent_image_ids
974+
# return the inverted latents (start point for the denoising loop), encoded image & latent image ids
975+
return Y_t, image_latents, latent_image_ids

0 commit comments

Comments
 (0)