@@ -394,8 +394,8 @@ def data_prediction_fn(self, x, t):
         if self.thresholding:
             p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
             s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
-            s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims)
-            x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
         return x0
 
     def model_fn(self, x, t):
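For reference, a minimal standalone sketch of the dynamic-thresholding step that the fixed lines implement. The helper name, the `max_val=1.0` default, and the `view`-based broadcasting are illustrative assumptions; the class itself uses `expand_dims(..., dims)` and `self.max_val`.

import torch

def dynamic_threshold(x0, p=0.995, max_val=1.0):
    # Per-sample p-quantile of |x0| over all non-batch dimensions.
    s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
    # The threshold is at least max_val; reshape so it broadcasts against x0.
    s = torch.maximum(s, max_val * torch.ones_like(s))
    s = s.view(-1, *([1] * (x0.dim() - 1)))
    # Clip to [-s, s] and rescale by s, matching the '+' lines above.
    return torch.clamp(x0, -s, s) / s

x0 = torch.randn(4, 3, 32, 32) * 2.0        # a fake data prediction
print(dynamic_threshold(x0).abs().max())    # <= 1.0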
@@ -436,7 +436,7 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device):
         else:
             raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
 
-    def get_orders_for_singlestep_solver(self, steps, order):
+    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
         """
         Get the order of each step for sampling by the singlestep DPM-Solver.
 
@@ -458,6 +458,13 @@ def get_orders_for_singlestep_solver(self, steps, order):
         Args:
             order: A `int`. The max order for the solver (2 or 3).
             steps: A `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
         Returns:
             orders: A list of the solver order of each step.
         """
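The three `skip_type` options listed above space the N + 1 time steps between `t_T` and `t_0` differently. Below is a rough sketch of the 'time_uniform' and 'time_quadratic' spacings, assuming a plain `torch.linspace` construction; the class's own `get_time_steps` also handles 'logSNR' through the noise schedule, which is outside this excerpt.

import torch

def sketch_time_steps(skip_type, t_T, t_0, N):
    if skip_type == 'time_uniform':
        # Evenly spaced in t; recommended for high-resolution data.
        return torch.linspace(t_T, t_0, N + 1)
    elif skip_type == 'time_quadratic':
        # Evenly spaced in sqrt(t), i.e. denser near t_0 (as in DDIM).
        return torch.linspace(t_T ** 0.5, t_0 ** 0.5, N + 1).pow(2)
    raise NotImplementedError("'logSNR' needs marginal_lambda / inverse_lambda from the noise schedule.")

print(sketch_time_steps('time_uniform', 1.0, 1e-3, 10))
print(sketch_time_steps('time_quadratic', 1.0, 1e-3, 10))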
@@ -469,20 +476,26 @@ def get_orders_for_singlestep_solver(self, steps, order):
                 orders = [3,] * (K - 1) + [1]
             else:
                 orders = [3,] * (K - 1) + [2]
-            return orders
         elif order == 2:
-            K = steps // 2
             if steps % 2 == 0:
+                K = steps // 2
                 orders = [2,] * K
             else:
-                orders = [2,] * K + [1]
-            return orders
+                K = steps // 2 + 1
+                orders = [2,] * (K - 1) + [1]
         elif order == 1:
-            return [1,] * steps
+            K = 1
+            orders = [1,] * steps
         else:
             raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == 'logSNR':
+            # To reproduce the results in the DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), dim=0).to(device)]
+        return timesteps_outer, orders
 
-    def denoise_fn(self, x, s):
+    def denoise_to_zero_fn(self, x, s):
         """
         Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
         """
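To make the new splitting concrete, here is a hedged sketch of how `steps` function evaluations are divided into single-step update orders. The `order == 3` prefix (`K = steps // 3 + 1` and the `steps % 3 == 0` case) lies above this hunk and is reproduced here as an assumption about the surrounding function.

import torch

def split_orders(steps, order):
    # Split `steps` NFEs into K single-step updates whose orders sum to `steps`.
    if order == 3:
        K = steps // 3 + 1                    # assumed from the lines above this hunk
        if steps % 3 == 0:
            orders = [3,] * (K - 2) + [2, 1]  # assumed from the lines above this hunk
        elif steps % 3 == 1:
            orders = [3,] * (K - 1) + [1]
        else:
            orders = [3,] * (K - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            K = steps // 2
            orders = [2,] * K
        else:
            K = steps // 2 + 1
            orders = [2,] * (K - 1) + [1]
    elif order == 1:
        K = 1
        orders = [1,] * steps
    else:
        raise ValueError("'order' must be '1' or '2' or '3'.")
    return K, orders

print(split_orders(12, 3))   # (5, [3, 3, 3, 2, 1])
print(split_orders(7, 2))    # (4, [2, 2, 2, 1])
# The cumulative sums index the fine time grid to pick out the outer time steps:
print(torch.cumsum(torch.tensor([0, 3, 3, 3, 2, 1]), dim=0))   # tensor([ 0,  3,  6,  9, 11, 12])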
@@ -950,8 +963,8 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol
         return x
 
     def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
-        method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078,
-        rtol=0.05,
+        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+        atol=0.0078, rtol=0.05,
     ):
         """
         Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
@@ -1035,8 +1048,19 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time
             order: A `int`. The order of DPM-Solver.
             skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
             method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
-            denoise: A `bool`. Whether to denoise at the final step. Default is False.
-                If `denoise` is True, the total NFE is (`steps` + 1).
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
+                sampling from diffusion models by diffusion SDEs for low-resolution images
+                (such as CIFAR-10). However, we observed that it does not matter for
+                high-resolution images, and since it needs an additional NFE, we do not
+                recommend it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower-order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is key to stabilizing sampling by DPM-Solver with very few steps
+                (especially for steps <= 10), so we recommend setting it to `True`.
             solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
             atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
             rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
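A hedged usage sketch of the updated `sample` signature. The `model_fn`, `noise_schedule`, and image shape are placeholders; in this repo they would normally come from `model_wrapper(...)` and `NoiseScheduleVP(...)`, neither of which appears in this diff.

import torch

# `model_fn` and `noise_schedule` are placeholders here (see the note above).
dpm_solver = DPM_Solver(model_fn, noise_schedule)

x_T = torch.randn(4, 3, 64, 64)                # initial Gaussian noise
x_0 = dpm_solver.sample(
    x_T,
    steps=10,
    order=3,
    skip_type='time_uniform',
    method='multistep',
    lower_order_final=True,    # recommended for very few steps (<= 10)
    denoise_to_zero=False,     # True adds one extra NFE at the final step
    solver_type='dpm_solver',
)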
@@ -1067,7 +1091,11 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time
             # Compute the remaining values by `order`-th order multistep DPM-Solver.
             for step in range(order, steps + 1):
                 vec_t = timesteps[step].expand(x.shape[0])
-                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
+                if lower_order_final and steps < 15:
+                    step_order = min(order, steps + 1 - step)
+                else:
+                    step_order = order
+                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
                 for i in range(order - 1):
                     t_prev_list[i] = t_prev_list[i + 1]
                     model_prev_list[i] = model_prev_list[i + 1]
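The effect of `lower_order_final` on the multistep schedule, as a small sketch; the order list below follows directly from the `min(order, steps + 1 - step)` rule in the added lines.

def multistep_orders(steps, order, lower_order_final=True):
    # Solver order used at each step of the loop above (after the first `order` warm-up steps).
    if lower_order_final and steps < 15:
        return [min(order, steps + 1 - step) for step in range(order, steps + 1)]
    return [order] * (steps + 1 - order)

print(multistep_orders(10, 3))                            # [3, 3, 3, 3, 3, 3, 2, 1]
print(multistep_orders(10, 3, lower_order_final=False))   # [3, 3, 3, 3, 3, 3, 3, 3]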
@@ -1077,23 +1105,22 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time
                     model_prev_list[-1] = self.model_fn(x, vec_t)
         elif method in ['singlestep', 'singlestep_fixed']:
             if method == 'singlestep':
-                orders = self.get_orders_for_singlestep_solver(steps=steps, order=order)
-                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
             elif method == 'singlestep_fixed':
                 K = steps // order
                 orders = [order,] * K
-                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device)
-            with torch.no_grad():
-                i = 0
-                for order in orders:
-                    vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0])
-                    h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
-                    r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
-                    r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
-                    x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
-                    i += order
-            if denoise:
-                x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
         return x
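Finally, a small sketch of how the intermediate ratios `r1` and `r2` fall out of the inner time grid above. Uniformly spaced log-SNR values stand in for `noise_schedule.marginal_lambda`, which is not defined in this excerpt.

import torch

order = 3
# Stand-in for marginal_lambda(timesteps_inner): order + 1 log-SNR values on one outer interval.
lambda_inner = torch.linspace(-2.0, -1.0, order + 1)

h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
# r1 and r2 are the relative positions of the two intermediate lambda points within the step.
print(h.item(), r1.item(), r2.item())   # 1.0 0.333... 0.666...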