@@ -835,4 +835,74 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n
835835 else :
836836 x *= torch .sqrt (1.0 + sigmas [i + 1 ] ** 2 )
837837
838+ return x
839+
840+
841+ @torch .no_grad ()
842+ def sample_restart (model , x , sigmas , extra_args = None , callback = None , disable = None , s_noise = 1. , restart_list = None ):
843+ """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
844+ Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
845+ If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
846+ """
847+ extra_args = {} if extra_args is None else extra_args
848+ s_in = x .new_ones ([x .shape [0 ]])
849+ step_id = 0
850+
851+ def heun_step (x , old_sigma , new_sigma , second_order = True ):
852+ nonlocal step_id
853+ denoised = model (x , old_sigma * s_in , ** extra_args )
854+ d = to_d (x , old_sigma , denoised )
855+ if callback is not None :
856+ callback ({'x' : x , 'i' : step_id , 'sigma' : new_sigma , 'sigma_hat' : old_sigma , 'denoised' : denoised })
857+ dt = new_sigma - old_sigma
858+ if new_sigma == 0 or not second_order :
859+ # Euler method
860+ x = x + d * dt
861+ else :
862+ # Heun's method
863+ x_2 = x + d * dt
864+ denoised_2 = model (x_2 , new_sigma * s_in , ** extra_args )
865+ d_2 = to_d (x_2 , new_sigma , denoised_2 )
866+ d_prime = (d + d_2 ) / 2
867+ x = x + d_prime * dt
868+ step_id += 1
869+ return x
870+
871+ steps = sigmas .shape [0 ] - 1
872+ if restart_list is None :
873+ if steps >= 20 :
874+ restart_steps = 9
875+ restart_times = 1
876+ if steps >= 36 :
877+ restart_steps = steps // 4
878+ restart_times = 2
879+ sigmas = get_sigmas_karras (steps - restart_steps * restart_times , sigmas [- 2 ].item (), sigmas [0 ].item (), device = sigmas .device )
880+ restart_list = {0.1 : [restart_steps + 1 , restart_times , 2 ]}
881+ else :
882+ restart_list = {}
883+
884+ restart_list = {int (torch .argmin (abs (sigmas - key ), dim = 0 )): value for key , value in restart_list .items ()}
885+
886+ step_list = []
887+ for i in range (len (sigmas ) - 1 ):
888+ step_list .append ((sigmas [i ], sigmas [i + 1 ]))
889+ if i + 1 in restart_list :
890+ restart_steps , restart_times , restart_max = restart_list [i + 1 ]
891+ min_idx = i + 1
892+ max_idx = int (torch .argmin (abs (sigmas - restart_max ), dim = 0 ))
893+ if max_idx < min_idx :
894+ sigma_restart = get_sigmas_karras (restart_steps , sigmas [min_idx ].item (), sigmas [max_idx ].item (), device = sigmas .device )[:- 1 ]
895+ while restart_times > 0 :
896+ restart_times -= 1
897+ step_list .extend (zip (sigma_restart [:- 1 ], sigma_restart [1 :]))
898+
899+ last_sigma = None
900+ for old_sigma , new_sigma in tqdm (step_list , disable = disable ):
901+ if last_sigma is None :
902+ last_sigma = old_sigma
903+ elif last_sigma < old_sigma :
904+ x = x + torch .randn_like (x ) * s_noise * (old_sigma ** 2 - last_sigma ** 2 ) ** 0.5
905+ x = heun_step (x , old_sigma , new_sigma )
906+ last_sigma = new_sigma
907+
838908 return x
0 commit comments