 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
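For context, `is_scipy_available()` is one of diffusers' soft-dependency checks: scipy is only imported when present, and only the beta-sigma path needs it. A rough standalone sketch of the same guard pattern (not the library's actual helper) might look like:

```python
# Minimal sketch of the optional-dependency guard above; standard library only.
# find_spec reports whether a module is importable without importing it.
import importlib.util


def is_scipy_available() -> bool:
    return importlib.util.find_spec("scipy") is not None


if is_scipy_available():
    import scipy.stats  # imported only when the optional dependency exists
```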
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
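A minimal, hypothetical usage sketch of the new flag (step count arbitrary; requires scipy to be installed):

```python
# Hypothetical usage of the new config flag; 25 steps chosen arbitrarily.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.sigmas)  # noise levels spaced by Beta(0.6, 0.6) quantiles
```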
@@ -209,6 +216,7 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -217,8 +225,12 @@ def __init__(
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
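The new guards can be exercised directly; a quick sketch of the mutual-exclusivity check (assuming scipy is installed, so the ImportError branch is not hit):

```python
# Sketch of the new __init__ validation: two sigma schedules cannot be combined.
from diffusers import DPMSolverMultistepScheduler

try:
    DPMSolverMultistepScheduler(use_beta_sigmas=True, use_karras_sigmas=True)
except ValueError as err:
    print(err)  # Only one of `config.use_beta_sigmas`, ... can be used.
```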
@@ -337,6 +349,8 @@ def set_timesteps(
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -388,6 +402,9 @@ def set_timesteps(
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
@@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
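For intuition about what `_convert_to_beta` computes: uniformly spaced probabilities are pushed through the Beta(alpha, beta) quantile function and rescaled to [sigma_min, sigma_max]. With alpha = beta = 0.6 the distribution is U-shaped, so the resulting sigmas cluster near both ends of the range. A standalone numeric sketch (the sigma bounds are made-up values):

```python
# Recomputes the beta-sigma spacing outside the scheduler; needs numpy + scipy.
import numpy as np
import scipy.stats

alpha, beta = 0.6, 0.6
sigma_min, sigma_max = 0.03, 14.6  # illustrative stand-ins for the real bounds
num_inference_steps = 8

# Quantiles of Beta(0.6, 0.6) at reversed uniform probabilities, then rescaled.
quantiles = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), alpha, beta)
sigmas = sigma_min + quantiles * (sigma_max - sigma_min)
print(sigmas)  # decreasing from sigma_max to sigma_min, dense near both ends
```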