Skip to content

Commit 6df3f2c

Browse files
committed
maybe
1 parent 462a02c commit 6df3f2c

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/diffusers/pipelines/flux/pipeline_flux_rfinversion_noise.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,8 @@ def __call__(
627627
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
628628
max_sequence_length: int = 512,
629629
sigmas = None,
630+
flip_schedule = False,
631+
even_timesteps = None,
630632
):
631633
r"""
632634
Function invoked when calling the pipeline for generation.
@@ -781,6 +783,11 @@ def __call__(
781783
sigmas,
782784
mu=mu,
783785
)
786+
if flip_schedule:
787+
self.scheduler.sigmas = self.scheduler.sigmas.flip(0)
788+
self.scheduler.timesteps = self.scheduler.timesteps.flip(0)
789+
print(f"self.scheduler.sigmas {self.scheduler.sigmas}")
790+
print(f"self.scheduler.timesteps {self.scheduler.timesteps}")
784791
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
785792

786793
if num_inference_steps < 1:
@@ -827,8 +834,11 @@ def __call__(
827834
continue
828835
if self.interrupt:
829836
continue
830-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
831-
timestep = timestep / 1000
837+
if even_timesteps is None:
838+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
839+
timestep = timestep / 1000
840+
else:
841+
timestep = even_timesteps[i]
832842
noise_pred = self.transformer(
833843
hidden_states=latents,
834844
timestep=timestep,

0 commit comments

Comments
 (0)