 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from .scheduling_utils import SchedulerMixin


+if is_scipy_available():
+    import scipy.stats
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -72,7 +75,16 @@ def __init__(
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
         invert_sigmas: bool = False,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

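Only one of the new sigma schedules can be enabled at a time, and beta sigmas additionally require scipy. A minimal usage sketch (not part of the diff; it assumes the public diffusers import path for this scheduler and hypothetical argument values):

from diffusers import FlowMatchEulerDiscreteScheduler

# At most one of the three flags may be set; beta sigmas also need scipy installed.
scheduler = FlowMatchEulerDiscreteScheduler(shift=3.0, use_karras_sigmas=True)

# Setting more than one flag trips the ValueError added above:
# FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True, use_beta_sigmas=True)  # raises ValueError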
@@ -185,12 +197,14 @@ def set_timesteps(
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
+        if num_inference_steps is None:
+            num_inference_steps = len(sigmas) - 1
+        self.num_inference_steps = num_inference_steps

         if self.config.use_dynamic_shifting and mu is None:
             raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")

         if sigmas is None:
-            self.num_inference_steps = num_inference_steps
             timesteps = np.linspace(
                 self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
             )
@@ -202,6 +216,15 @@ def set_timesteps(
         else:
             sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

+        if self.config.use_karras_sigmas:
+            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
+        elif self.config.use_exponential_sigmas:
+            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=len(sigmas))
+
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
         timesteps = sigmas * self.config.num_train_timesteps

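With one of the flags enabled, set_timesteps reshapes the sigma schedule before the timesteps are derived from it. A short sketch of the expected call pattern (hypothetical step count and device, assuming the scheduler is constructed as above):

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
scheduler.set_timesteps(num_inference_steps=28, device="cpu")

# sigmas now follow the Karras rho=7 spacing rather than the default linear spacing
print(scheduler.sigmas[:4])
print(scheduler.timesteps[:4])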
@@ -314,5 +337,85 @@ def step(

         return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+        """Constructs the noise schedule of Karras et al. (2022)."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        rho = 7.0  # 7.0 is the value used in the paper
+        ramp = np.linspace(0, 1, num_inference_steps)
+        min_inv_rho = sigma_min ** (1 / rho)
+        max_inv_rho = sigma_max ** (1 / rho)
+        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """Constructs an exponential noise schedule."""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
+
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def __len__(self):
         return self.config.num_train_timesteps
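For intuition, _convert_to_beta only spaces the sigmas by the quantile function (PPF) of a Beta(0.6, 0.6) distribution, which clusters steps near both ends of the range. A standalone sketch with illustrative numbers (requires scipy; the values are hypothetical, not taken from the PR):

import numpy as np
import scipy.stats

sigma_min, sigma_max, steps = 0.02, 1.0, 8
ppf = scipy.stats.beta.ppf(1 - np.linspace(0, 1, steps), 0.6, 0.6)
sigmas = sigma_min + ppf * (sigma_max - sigma_min)
print(sigmas)  # descending from sigma_max to sigma_min, denser near both ends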