@@ -22,10 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
-if is_scipy_available():
-    import scipy.stats
+from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas


 @dataclass
@@ -146,20 +143,14 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
-        use_karras_sigmas: Optional[bool] = False,
-        use_exponential_sigmas: Optional[bool] = False,
-        use_beta_sigmas: Optional[bool] = False,
+        sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if self.config.use_beta_sigmas and not is_scipy_available():
+        if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError(
-                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
-            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -178,9 +169,10 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

+        self.sigma_schedule = sigma_schedule
+
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
-        self.use_karras_sigmas = use_karras_sigmas

         self._step_index = None
         self._begin_index = None
@@ -287,12 +279,8 @@ def set_timesteps(
             raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
-        if timesteps is not None and self.config.use_karras_sigmas:
-            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
-        if timesteps is not None and self.config.use_exponential_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if timesteps is not None and self.config.use_beta_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
+        if timesteps is not None and self.sigma_schedule is not None:
+            raise ValueError("Cannot use `timesteps` with `sigma_schedule`")

         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -325,14 +313,8 @@ def set_timesteps(
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

-        if self.config.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+        if self.sigma_schedule is not None:
+            sigmas = self.sigma_schedule(sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -376,86 +358,6 @@ def _sigma_to_t(self, sigma, log_sigmas):
         t = t.reshape(sigma.shape)
         return t

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, num_inference_steps)
-        min_inv_rho = sigma_min ** (1 / rho)
-        max_inv_rho = sigma_max ** (1 / rho)
-        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
-    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
-        """Constructs an exponential noise schedule."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
-    def _convert_to_beta(
-        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
-    ) -> torch.Tensor:
-        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.array(
-            [
-                sigma_min + (ppf * (sigma_max - sigma_min))
-                for ppf in [
-                    scipy.stats.beta.ppf(timestep, alpha, beta)
-                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
-                ]
-            ]
-        )
-        return sigmas
-
     @property
     def state_in_first_order(self):
         return self.dt is None
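Note (editor's sketch, not part of this diff): the patch replaces the three `use_*_sigmas` booleans with a single `sigma_schedule` callable imported from a new `.sigmas` module, and `set_timesteps` now simply calls `self.sigma_schedule(sigmas)`. That module is not shown here, so its exact API is an assumption; the sketch below illustrates one way a `KarrasSigmas` callable could package the logic of the removed `_convert_to_karras` helper, recovering the step count from the length of the incoming sigma array.

```python
# Hypothetical sketch only: the real KarrasSigmas lives in the unshown
# `.sigmas` module, so this name/signature is an assumption, not the PR's code.
import numpy as np


class KarrasSigmas:
    """Karras et al. (2022) sigma spacing, mirroring the removed
    `_convert_to_karras` helper but exposed as a reusable callable."""

    def __init__(self, rho: float = 7.0):
        self.rho = rho  # 7.0 is the value used in the paper

    def __call__(self, in_sigmas: np.ndarray) -> np.ndarray:
        # At the call site in set_timesteps the array already holds one sigma
        # per inference step, so the step count is just its length.
        num_inference_steps = len(in_sigmas)
        sigma_min = in_sigmas[-1].item()
        sigma_max = in_sigmas[0].item()
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / self.rho)
        max_inv_rho = sigma_max ** (1 / self.rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
```

Under this API a caller would presumably pass `sigma_schedule=KarrasSigmas()` to the scheduler instead of `use_karras_sigmas=True`; the class name comes from the new import line, while the constructor arguments above are assumed for illustration.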