@@ -21,11 +21,15 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -209,6 +216,7 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -217,8 +225,12 @@ def __init__(
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -337,6 +349,8 @@ def set_timesteps(
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")

         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -388,6 +402,9 @@ def set_timesteps(
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

@@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
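For reviewers, a minimal usage sketch of the new flag (not part of the diff): it constructs the scheduler with `use_beta_sigmas=True` and inspects the resulting sigma schedule. The step count of 25 is an arbitrary choice for illustration; everything else keeps the constructor defaults.

```python
# Minimal sketch: enable the new beta-sigma spacing and inspect the result.
# Assumes scipy is installed (required by `use_beta_sigmas`; see the
# ImportError guard added in __init__ above).
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=25)

# With the default alpha = beta = 0.6, Beta(0.6, 0.6) is U-shaped, so its PPF
# clusters the sigma schedule near both ends of [sigma_min, sigma_max]: the
# solver takes finer steps at the high-noise start and low-noise finish than
# a uniform spacing would.
print(scheduler.sigmas)
```

Note that, per the `ValueError` added in `__init__`, the spacing flags are mutually exclusive: setting `use_beta_sigmas=True` requires `use_karras_sigmas` and `use_exponential_sigmas` to remain `False`.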