@@ -1557,10 +1557,13 @@ def default_er_sde_noise_scaler(x):
15571557
15581558
15591559@torch .no_grad ()
1560- def sample_seeds_2 (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r = 0.5 ):
1560+ def sample_seeds_2 (model , x , sigmas , extra_args = None , callback = None , disable = None , eta = 1. , s_noise = 1. , noise_sampler = None , r = 0.5 , solver_type = "phi_1" ):
15611561 """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
15621562 arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
15631563 """
1564+ if solver_type not in {"phi_1" , "phi_2" }:
1565+ raise ValueError ("solver_type must be 'phi_1' or 'phi_2'" )
1566+
15641567 extra_args = {} if extra_args is None else extra_args
15651568 seed = extra_args .get ("seed" , None )
15661569 noise_sampler = default_noise_sampler (x , seed = seed ) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
16001603 denoised_2 = model (x_2 , sigma_s_1 * s_in , ** extra_args )
16011604
16021605 # Step 2
1603- denoised_d = torch .lerp (denoised , denoised_2 , fac )
1604- x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x - alpha_t * ei_h_phi_1 (- h_eta ) * denoised_d
1606+ if solver_type == "phi_1" :
1607+ denoised_d = torch .lerp (denoised , denoised_2 , fac )
1608+ x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x - alpha_t * ei_h_phi_1 (- h_eta ) * denoised_d
1609+ elif solver_type == "phi_2" :
1610+ b2 = ei_h_phi_2 (- h_eta ) / r
1611+ b1 = ei_h_phi_1 (- h_eta ) - b2
1612+ x = sigmas [i + 1 ] / sigmas [i ] * (- h * eta ).exp () * x - alpha_t * (b1 * denoised + b2 * denoised_2 )
1613+
16051614 if inject_noise :
16061615 segment_factor = (r - 1 ) * h * eta
16071616 sde_noise = sde_noise * segment_factor .exp ()
0 commit comments