@@ -1618,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
16181618 x = x + sde_noise * sigmas [i + 1 ] * s_noise
16191619 return x
16201620
1621+ @torch .no_grad ()
1622+ def sample_exp_heun_2_x0 (model , x , sigmas , extra_args = None , callback = None , disable = None , solver_type = "phi_2" ):
1623+ """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
1624+ return sample_seeds_2 (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = 0.0 , s_noise = 0.0 , noise_sampler = None , r = 1.0 , solver_type = solver_type )
1625+
1626+
1627+ @torch .no_grad ()
1628+ def sample_exp_heun_2_x0_sde (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , solver_type = "phi_2" ):
1629+ """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
1630+ return sample_seeds_2 (model , x , sigmas , extra_args = extra_args , callback = callback , disable = disable , eta = eta , s_noise = s_noise , noise_sampler = noise_sampler , r = 1.0 , solver_type = solver_type )
1631+
16211632
16221633@torch .no_grad ()
16231634def sample_seeds_3 (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r_1 = 1. / 3 , r_2 = 2. / 3 ):
@@ -1765,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
17651776 # Predictor
17661777 if sigmas [i + 1 ] == 0 :
17671778 # Denoising step
1768- x = denoised
1779+ x_pred = denoised
17691780 else :
17701781 tau_t = tau_func (sigmas [i + 1 ])
17711782 curr_lambdas = lambdas [i - predictor_order_used + 1 :i + 1 ]
@@ -1786,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
17861797 if tau_t > 0 and s_noise > 0 :
17871798 noise = noise_sampler (sigmas [i ], sigmas [i + 1 ]) * sigmas [i + 1 ] * (- 2 * tau_t ** 2 * h ).expm1 ().neg ().sqrt () * s_noise
17881799 x_pred = x_pred + noise
1789- return x
1800+ return x_pred
17901801
17911802
17921803@torch .no_grad ()
0 commit comments