Skip to content

Commit baae6e7

Browse files
committed
add decay
1 parent 6aaa268 commit baae6e7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,8 @@ def __call__(
648648
height: Optional[int] = None,
649649
width: Optional[int] = None,
650650
eta: float = 1.0,
651+
decay_eta: Optional[bool] = False,
652+
eta_decay_power: Optional[float] = 1.,
651653
strength: float = 1.0,
652654
start_timestep: float = 0,
653655
stop_timestep: float = 0.25,
@@ -881,8 +883,7 @@ def __call__(
881883
v_t_cond = (y_0 - latents) / (1 - t_i)
882884
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883885
if decay_eta:
884-
eta_t = eta_t * (1 - i / num_inference_steps) # Decay eta over the loop
885-
# eta_t = eta * (1 - i / num_inference_steps) ** 2
886+
eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop
886887
v_hat_t = v_t + eta_t * (v_t_cond - v_t)
887888

888889
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792

0 commit comments

Comments
 (0)