Skip to content

Commit 6aaa268

Browse files
committed
add decay
1 parent ad40e26 commit 6aaa268

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -880,12 +880,10 @@ def __call__(
880880
v_t = -noise_pred
881881
v_t_cond = (y_0 - latents) / (1 - t_i)
882882
eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883-
if start_timestep <= i < stop_timestep:
884-
# controlled vector field
885-
v_hat_t = v_t + eta * (v_t_cond - v_t)
886-
887-
else:
888-
v_hat_t = v_t
883+
if decay_eta:
884+
eta_t = eta_t * (1 - i / num_inference_steps) # Decay eta over the loop
885+
# eta_t = eta * (1 - i / num_inference_steps) ** 2
886+
v_hat_t = v_t + eta_t * (v_t_cond - v_t)
889887

890888
# SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
891889
latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])

0 commit comments

Comments
 (0)