
Commit 8708d60

maybe
1 parent 6df3f2c commit 8708d60

File tree

2 files changed: +21 -8 lines


src/diffusers/pipelines/flux/pipeline_flux_rfinversion_edit.py

Lines changed: 16 additions & 5 deletions
@@ -627,6 +627,10 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        sigmas = None,
+        flip_schedule = False,
+        even_timesteps = None,
+        divide_timestep = True,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -763,8 +767,6 @@ def __call__(
         )

         # 4.Prepare timesteps
-        # Flux noise scheduler $\sigma : [0, 1] \to \mathbb{R}$
-        sigmas = np.linspace(0.0, 1.0, num_inference_steps)
         image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
         mu = calculate_shift(
             image_seq_len,
@@ -781,6 +783,11 @@ def __call__(
             sigmas,
             mu=mu,
         )
+        if flip_schedule:
+            self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
+            self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
+            print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
+            print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

         if num_inference_steps < 1:
@@ -837,13 +844,17 @@ def __call__(
                 if self.interrupt:
                     continue
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                timestep = timestep / 1000
+                if even_timesteps is None:
+                    timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                    if divide_timestep:
+                        timestep = timestep / 1000
+                else:
+                    timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
                 # Unconditional vector field: $v_{t_i}(X_{t_i}) = -u(X_{t_i}, 1 - t_i, \Phi(\text{prompt}); \phi)$
                 timestep = 1-timestep
                 unconditional_vector_field = -self.transformer(
                     hidden_states=latents,
-                    timestep=timestep,
+                    timestep=timestep if divide_timestep else timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
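
Note: a minimal usage sketch of the new keyword arguments added to __call__ in this commit. The class name FluxRFInversionEditPipeline, the checkpoint id, and every call argument other than the four new ones are assumptions (the edit pipeline's image/latent inputs are omitted); only sigmas, flip_schedule, even_timesteps, and divide_timestep come from this diff.

import torch
# Hypothetical class name inferred from the file name; the module path matches this repo's file.
from diffusers.pipelines.flux.pipeline_flux_rfinversion_edit import FluxRFInversionEditPipeline

pipe = FluxRFInversionEditPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

out = pipe(
    prompt="a cat wearing a top hat",
    num_inference_steps=28,
    strength=0.95,
    # New arguments introduced by this commit:
    sigmas=None,              # custom sigma schedule forwarded to the scheduler setup
    flip_schedule=True,       # reverse self.scheduler.sigmas / .timesteps before the loop
    even_timesteps=None,      # or an explicit list of per-step timestep values
    divide_timestep=True,     # rescale scheduler timesteps from [0, 1000] to [0, 1] early
)

Passing even_timesteps replaces the scheduler's per-step timestep inside the denoising loop, while divide_timestep controls whether the /1000 rescaling happens when the timestep is built or only at the transformer call.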

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 5 additions & 3 deletions
@@ -629,6 +629,7 @@ def __call__(
         sigmas = None,
         flip_schedule = False,
         even_timesteps = None,
+        divide_timestep = True
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -836,12 +837,13 @@ def __call__(
                     continue
                 if even_timesteps is None:
                     timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                    timestep = timestep / 1000
+                    if divide_timestep:
+                        timestep = timestep / 1000
                 else:
-                    timestep = even_timesteps[i]
+                    timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    timestep=timestep,
+                    timestep=timestep if divide_timestep else timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
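
Note: both files now share the same per-step timestep selection, so here is a self-contained sketch of that logic with dummy tensors; the latent shape and schedule values are placeholders, not values from this commit.

import torch

# Placeholder stand-ins for objects the pipeline would provide.
latents = torch.randn(1, 4096, 64)               # packed Flux latents (dummy shape)
timesteps = torch.linspace(1000.0, 1.0, 28)      # scheduler timesteps on the [0, 1000] scale
even_timesteps = None                            # or e.g. torch.linspace(1.0, 0.0, 28).tolist()
divide_timestep = True

for i, t in enumerate(timesteps):
    if even_timesteps is None:
        # Broadcast the scheduler timestep to the batch dimension.
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        if divide_timestep:
            timestep = timestep / 1000  # rescale to [0, 1] here ...
    else:
        # Use the caller-supplied schedule as a one-element tensor on the latents' device/dtype.
        timestep = torch.tensor([even_timesteps[i]], device=latents.device, dtype=latents.dtype)

    # ... otherwise the division is deferred to the transformer call site.
    transformer_timestep = timestep if divide_timestep else timestep / 1000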
