@@ -114,7 +114,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         lower_order_final (`bool`, default `True`):
             whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
             find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10.
-
+        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use Karras sigmas (the Karras et al. (2022) scheme) for the step sizes in the noise schedule
+            during sampling. If `True`, the sigmas are determined according to a sequence of noise levels {σi} as
+            defined in Equation (5) of https://arxiv.org/pdf/2206.00364.pdf.
118121 """
119122
120123 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
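
For context (not part of the diff): the new flag is expected to be enabled by overriding the scheduler config on an existing pipeline. A minimal usage sketch, assuming a standard Stable Diffusion checkpoint (the model id below is illustrative):

```python
# Minimal sketch: swap a pipeline's scheduler for DPMSolverMultistepScheduler with
# Karras sigmas enabled. The checkpoint id is illustrative, not part of this change.
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, use_karras_sigmas=True
)
image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
```
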
@@ -136,6 +139,7 @@ def __init__(
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
         lower_order_final: bool = True,
+        use_karras_sigmas: Optional[bool] = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +185,7 @@ def __init__(
         self.timesteps = torch.from_numpy(timesteps)
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
+        self.use_karras_sigmas = use_karras_sigmas

     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
@@ -199,6 +204,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
             .astype(np.int64)
         )

+        if self.use_karras_sigmas:
+            sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            log_sigmas = np.log(sigmas)
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
+            timesteps = np.flip(timesteps).copy().astype(np.int64)
+
         # when num_inference_steps == num_train_timesteps, we can end up with
         # duplicates in timesteps.
         _, unique_indices = np.unique(timesteps, return_index=True)
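
For reference, a self-contained sketch (not part of the diff) of what the `use_karras_sigmas` branch computes: training sigmas from `alphas_cumprod`, Karras-spaced noise levels, and the integer timesteps they map back to. A toy linear beta schedule is assumed, and `np.interp` stands in for the piecewise log-sigma interpolation done by `_sigma_to_t` below.

```python
# Sketch of the use_karras_sigmas branch above, with a toy beta schedule.
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 10
betas = np.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), increasing with t
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
log_sigmas = np.log(sigmas)

# Karras spacing (rho = 7) between the smallest and largest training sigma
rho = 7.0
ramp = np.linspace(0, 1, num_inference_steps)
karras_sigmas = (sigmas[0] ** (1 / rho) + ramp * (sigmas[-1] ** (1 / rho) - sigmas[0] ** (1 / rho))) ** rho

# map each noise level back to a (fractional) training timestep, round, then reverse
timesteps = np.interp(np.log(karras_sigmas), log_sigmas, np.arange(num_train_timesteps))
timesteps = np.flip(timesteps.round()).astype(np.int64)
print(timesteps)  # descending timesteps selected by the Karras noise levels
```
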
@@ -248,6 +260,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:

         return sample

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
+    def _sigma_to_t(self, sigma, log_sigmas):
+        # get log sigma
+        log_sigma = np.log(sigma)
+
+        # get distribution
+        dists = log_sigma - log_sigmas[:, np.newaxis]
+
+        # get sigmas range
+        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+        high_idx = low_idx + 1
+
+        low = log_sigmas[low_idx]
+        high = log_sigmas[high_idx]
+
+        # interpolate sigmas
+        w = (low - log_sigma) / (low - high)
+        w = np.clip(w, 0, 1)
+
+        # transform interpolation to time range
+        t = (1 - w) * low_idx + w * high_idx
+        t = t.reshape(sigma.shape)
+        return t
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        sigma_min: float = in_sigmas[-1].item()
+        sigma_max: float = in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
     def convert_model_output(
         self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
     ) -> torch.FloatTensor:
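
As a quick numeric illustration of the Equation (5) schedule implemented by `_convert_to_karras` (the sigma range below is made up for the example, not taken from the scheduler's defaults):

```python
# rho = 7 ramp from Karras et al. (2022), Eq. (5), with illustrative sigma bounds.
import numpy as np

sigma_min, sigma_max, rho, n = 0.1, 10.0, 7.0, 5

ramp = np.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas)  # roughly [10.0, 4.07, 1.45, 0.43, 0.1]; spacing tightens toward sigma_min
```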