File tree Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -648,6 +648,8 @@ def __call__(
648648 height : Optional [int ] = None ,
649649 width : Optional [int ] = None ,
650650 eta : float = 1.0 ,
651+ decay_eta : Optional [bool ] = False ,
652+ eta_decay_power : Optional [float ] = 1.0 ,
651653 strength : float = 1.0 ,
652654 start_timestep : float = 0 ,
653655 stop_timestep : float = 0.25 ,
@@ -880,12 +882,9 @@ def __call__(
880882 v_t = - noise_pred
881883 v_t_cond = (y_0 - latents ) / (1 - t_i )
882884 eta_t = eta if start_timestep <= i < stop_timestep else 0.0
883- if start_timestep <= i < stop_timestep :
884- # controlled vector field
885- v_hat_t = v_t + eta * (v_t_cond - v_t )
886-
887- else :
888- v_hat_t = v_t
885+ if decay_eta :
886+ eta_t = eta_t * (1 - i / num_inference_steps ) ** eta_decay_power # Decay eta over the loop
887+ v_hat_t = v_t + eta_t * (v_t_cond - v_t )
889888
890889 # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
891890 latents = latents + v_hat_t * (sigmas [i ] - sigmas [i + 1 ])
You can’t perform that action at this time.
0 commit comments