 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
-if is_scipy_available():
-    import scipy.stats
+from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas
 
 
 @dataclass
@@ -119,21 +116,14 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Clip the predicted sample for numerical stability.
         clip_sample_range (`float`, defaults to 1.0):
             The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
-        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
-            the sigmas are determined according to a sequence of noise levels {σi}.
-        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
-        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
-            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
     """
 
+    ignore_for_config = ["sigma_schedule"]
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2
 
@@ -146,20 +136,14 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
-        use_karras_sigmas: Optional[bool] = False,
-        use_exponential_sigmas: Optional[bool] = False,
-        use_beta_sigmas: Optional[bool] = False,
+        sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if self.config.use_beta_sigmas and not is_scipy_available():
+        if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError(
-                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
-            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -178,9 +162,10 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
 
+        self.sigma_schedule = sigma_schedule
+
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
-        self.use_karras_sigmas = use_karras_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -287,12 +272,8 @@ def set_timesteps(
             raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
-        if timesteps is not None and self.config.use_karras_sigmas:
-            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
-        if timesteps is not None and self.config.use_exponential_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if timesteps is not None and self.config.use_beta_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
+        if timesteps is not None and self.sigma_schedule is not None:
+            raise ValueError("Cannot use `timesteps` with `sigma_schedule`")
 
         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -325,14 +306,8 @@ def set_timesteps(
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
-        if self.config.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+        if self.sigma_schedule is not None:
+            sigmas = self.sigma_schedule(sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -376,86 +351,6 @@ def _sigma_to_t(self, sigma, log_sigmas):
         t = t.reshape(sigma.shape)
         return t
 
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, num_inference_steps)
-        min_inv_rho = sigma_min ** (1 / rho)
-        max_inv_rho = sigma_max ** (1 / rho)
-        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
-    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
-        """Constructs an exponential noise schedule."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
-    def _convert_to_beta(
-        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
-    ) -> torch.Tensor:
-        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.array(
-            [
-                sigma_min + (ppf * (sigma_max - sigma_min))
-                for ppf in [
-                    scipy.stats.beta.ppf(timestep, alpha, beta)
-                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
-                ]
-            ]
-        )
-        return sigmas
-
     @property
     def state_in_first_order(self):
         return self.dt is None
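
Below is a minimal usage sketch of the refactored API introduced in this diff, assuming `KarrasSigmas` can be instantiated with no arguments and that the new sigmas module is importable as `diffusers.schedulers.sigmas`; neither detail is shown above, so treat both as assumptions.

    from diffusers import HeunDiscreteScheduler
    from diffusers.schedulers.sigmas import KarrasSigmas  # assumed import path for the new module

    # Pass a sigma-schedule object via the new `sigma_schedule` argument,
    # replacing the removed `use_karras_sigmas=True` config flag.
    scheduler = HeunDiscreteScheduler(sigma_schedule=KarrasSigmas())  # assumed no-arg constructor
    scheduler.set_timesteps(num_inference_steps=25)
    print(scheduler.sigmas)

Per the validation added in `set_timesteps`, passing custom `timesteps` together with a `sigma_schedule` raises a `ValueError`.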