 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
-if is_scipy_available():
-    import scipy.stats
+from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas


 @dataclass
@@ -119,21 +116,14 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Clip the predicted sample for numerical stability.
         clip_sample_range (`float`, defaults to 1.0):
             The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
-        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
-            the sigmas are determined according to a sequence of noise levels {σi}.
-        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
-        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
-            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
     """

+    ignore_for_config = ["sigma_schedule"]
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

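Note on the added class attribute: `ignore_for_config` tells `ConfigMixin` not to serialize the listed constructor arguments, which fits here because `sigma_schedule` holds a class instance rather than a plain JSON value. A hedged sanity check; the absolute import path for the new `.sigmas` module and the no-argument `KarrasSigmas()` constructor are assumptions, not shown in this diff:

```python
# Hedged sanity check: arguments listed in `ignore_for_config` are skipped by
# `register_to_config`, so the schedule object should not appear in the config.
from diffusers import HeunDiscreteScheduler
from diffusers.schedulers.sigmas import KarrasSigmas  # assumed absolute path

scheduler = HeunDiscreteScheduler(sigma_schedule=KarrasSigmas())
assert "sigma_schedule" not in scheduler.config
```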
@@ -146,20 +136,14 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
-        use_karras_sigmas: Optional[bool] = False,
-        use_exponential_sigmas: Optional[bool] = False,
-        use_beta_sigmas: Optional[bool] = False,
+        sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if self.config.use_beta_sigmas and not is_scipy_available():
+        if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError(
-                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
-            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -178,9 +162,10 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

+        self.sigma_schedule = sigma_schedule
+
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
-        self.use_karras_sigmas = use_karras_sigmas

         self._step_index = None
         self._begin_index = None
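For reference, a minimal usage sketch of the reworked constructor. Only the relative `.sigmas` import is confirmed by this diff, so the absolute import path and the no-argument `KarrasSigmas()` constructor below are hypothetical:

```python
# Hypothetical usage sketch: a sigma-schedule object replaces the old
# `use_karras_sigmas` / `use_exponential_sigmas` / `use_beta_sigmas` flags.
from diffusers import HeunDiscreteScheduler
from diffusers.schedulers.sigmas import KarrasSigmas  # assumed absolute path for `.sigmas`

scheduler = HeunDiscreteScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    prediction_type="epsilon",
    sigma_schedule=KarrasSigmas(),  # previously: use_karras_sigmas=True
)
scheduler.set_timesteps(num_inference_steps=30)
```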
@@ -287,12 +272,8 @@ def set_timesteps(
             raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
-        if timesteps is not None and self.config.use_karras_sigmas:
-            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
-        if timesteps is not None and self.config.use_exponential_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if timesteps is not None and self.config.use_beta_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
+        if timesteps is not None and self.sigma_schedule is not None:
+            raise ValueError("Cannot use `timesteps` with `sigma_schedule`")

         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -325,14 +306,8 @@ def set_timesteps(
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

-        if self.config.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+        if self.sigma_schedule is not None:
+            sigmas = self.sigma_schedule(sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
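The replacement branch treats `self.sigma_schedule` as a callable that maps the default interpolated sigmas to the sigmas actually used for sampling. A hedged sketch of that contract, modeled on the removed `_convert_to_exponential` helper further down in this diff; the real `ExponentialSigmas` class may differ, e.g. in honoring `sigma_min`/`sigma_max` from the config:

```python
import math

import numpy as np


class ExponentialSigmasSketch:
    """Hypothetical sigma schedule illustrating the callable contract:
    takes the interpolated sigmas (descending) and returns new sigmas."""

    def __call__(self, in_sigmas: np.ndarray) -> np.ndarray:
        sigma_min = in_sigmas[-1].item()
        sigma_max = in_sigmas[0].item()
        # Log-spaced sigmas from sigma_max down to sigma_min.
        return np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), len(in_sigmas)))
```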
@@ -376,86 +351,6 @@ def _sigma_to_t(self, sigma, log_sigmas):
         t = t.reshape(sigma.shape)
         return t

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, num_inference_steps)
-        min_inv_rho = sigma_min ** (1 / rho)
-        max_inv_rho = sigma_max ** (1 / rho)
-        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
-    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
-        """Constructs an exponential noise schedule."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
-    def _convert_to_beta(
-        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
-    ) -> torch.Tensor:
-        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.array(
-            [
-                sigma_min + (ppf * (sigma_max - sigma_min))
-                for ppf in [
-                    scipy.stats.beta.ppf(timestep, alpha, beta)
-                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
-                ]
-            ]
-        )
-        return sigmas
-
     @property
     def state_in_first_order(self):
         return self.dt is None
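The removed helpers are presumably what the new `BetaSigmas` / `ExponentialSigmas` / `KarrasSigmas` classes encapsulate. Purely as an illustration, a hedged sketch of the Karras case as a standalone callable, reusing the formula from the removed `_convert_to_karras`; the actual implementation in `.sigmas` may handle `sigma_min`/`sigma_max` overrides and `num_inference_steps` differently:

```python
import numpy as np


class KarrasSigmasSketch:
    """Hypothetical stand-in for `KarrasSigmas`: the Karras et al. (2022)
    rho-spaced noise schedule, lifted from the removed `_convert_to_karras`."""

    def __init__(self, rho: float = 7.0):
        self.rho = rho  # 7.0 is the value used in the paper

    def __call__(self, in_sigmas: np.ndarray) -> np.ndarray:
        sigma_min = in_sigmas[-1].item()
        sigma_max = in_sigmas[0].item()
        ramp = np.linspace(0, 1, len(in_sigmas))
        min_inv_rho = sigma_min ** (1 / self.rho)
        max_inv_rho = sigma_max ** (1 / self.rho)
        return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
```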