File tree Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Expand file tree Collapse file tree 1 file changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -648,6 +648,8 @@ def __call__(
648648        height : Optional [int ] =  None ,
649649        width : Optional [int ] =  None ,
650650        eta : float  =  1.0 ,
651+         decay_eta : Optional [bool ] =  False ,
652+         eta_decay_power : Optional [float ] =  1.0 ,
651653        strength : float  =  1.0 ,
652654        start_timestep : float  =  0 ,
653655        stop_timestep : float  =  0.25 ,
@@ -880,12 +882,9 @@ def __call__(
880882                    v_t  =  - noise_pred 
881883                    v_t_cond  =  (y_0  -  latents ) /  (1  -  t_i )
882884                    eta_t  =  eta  if  start_timestep  <=  i  <  stop_timestep  else  0.0 
883-                     if  start_timestep  <=  i  <  stop_timestep :
884-                         # controlled vector field 
885-                         v_hat_t  =  v_t  +  eta  *  (v_t_cond  -  v_t )
886- 
887-                     else :
888-                         v_hat_t  =  v_t 
885+                     if  decay_eta :
886+                         eta_t  =  eta_t  *  (1  -  i  /  num_inference_steps ) **  eta_decay_power   # Decay eta over the loop 
887+                     v_hat_t  =  v_t  +  eta_t  *  (v_t_cond  -  v_t )
889888
890889                    # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 
891890                    latents  =  latents  +  v_hat_t  *  (sigmas [i ] -  sigmas [i  +  1 ])
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments