Commit c4a8979

Add beta sigmas to other schedulers and update docs (#9538)

1 parent f9fd511

12 files changed: +551 −28 lines

docs/source/en/api/schedulers/overview.md

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowsonkb/k-diffusion)
 | sgm_uniform | init with `timestep_spacing="trailing"` |
 | simple | init with `timestep_spacing="trailing"` |
 | exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
+| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |
 
 All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.

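The new `beta` row above maps to the `use_beta_sigmas=True` flag this commit adds. As a minimal sketch of how a user would enable it (the scheduler class and step count here are one illustrative choice, and scipy must be installed):

```python
# Sketch only: assumes `pip install scipy` and a post-c4a8979 diffusers install.
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    timestep_spacing="linspace",  # as listed in the table above
    use_beta_sigmas=True,         # flag added by this commit
)
scheduler.set_timesteps(num_inference_steps=25)

# The same flag can be applied to an existing pipeline's scheduler:
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(
#     pipe.scheduler.config, use_beta_sigmas=True
# )
```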
src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 50 additions & 3 deletions
@@ -22,10 +22,14 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -113,6 +117,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -141,11 +148,16 @@ def __init__(
         lower_order_final: bool = True,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -263,6 +275,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -396,6 +411,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,

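The `_convert_to_beta` helper added above is easier to see in isolation. The following standalone sketch reproduces its core computation outside the scheduler; the `sigma_min`/`sigma_max` values and step count are illustrative, not taken from the commit:

```python
import numpy as np
import scipy.stats
import torch


def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
    # Inverse CDF (ppf) of a Beta(alpha, beta) distribution evaluated on a
    # reversed uniform grid, rescaled to [sigma_min, sigma_max] -- this
    # mirrors the list comprehension inside `_convert_to_beta`.
    ppfs = scipy.stats.beta.ppf(1 - np.linspace(0, 1, num_inference_steps), alpha, beta)
    return torch.tensor(sigma_min + ppfs * (sigma_max - sigma_min), dtype=torch.float32)


# Illustrative sigma range and step count (not from the commit):
print(beta_sigmas(0.03, 14.6, 10))
```

With `alpha = beta = 0.6` the Beta density is U-shaped, so the resulting sigmas cluster near `sigma_max` and `sigma_min` and are sparser in the middle, which is the step-size redistribution argued for in the referenced paper.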
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 52 additions & 3 deletions
@@ -21,11 +21,15 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -209,6 +216,7 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -217,8 +225,12 @@ def __init__(
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -337,6 +349,8 @@ def set_timesteps(
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -388,6 +402,9 @@ def set_timesteps(
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
@@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,

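The validation added to `__init__` above has two observable failure modes: an ImportError when scipy is missing, and a ValueError when more than one sigma schedule is requested. A small sketch of the latter (the flag combination is arbitrary):

```python
from diffusers import DPMSolverMultistepScheduler

try:
    DPMSolverMultistepScheduler(use_karras_sigmas=True, use_beta_sigmas=True)
except ValueError as err:
    # Raised by the guard added in this commit: only one of
    # `use_beta_sigmas`, `use_exponential_sigmas`, `use_karras_sigmas`
    # may be enabled at a time.
    print(err)
```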
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 51 additions & 3 deletions
@@ -21,11 +21,15 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -126,6 +130,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
@@ -161,13 +168,18 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         lambda_min_clipped: float = -float("inf"),
         variance_type: Optional[str] = None,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -219,6 +231,7 @@ def __init__(
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
     @property
     def step_index(self):
@@ -276,6 +289,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
             sigma_max = (
@@ -416,6 +432,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
     def convert_model_output(
         self,

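Taken together, these changes let any of the updated schedulers route `set_timesteps` through `_convert_to_beta`. A short end-to-end sketch using one of the schedulers touched by this commit (the step count is arbitrary; scipy is required):

```python
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler(use_beta_sigmas=True)
scheduler.set_timesteps(num_inference_steps=10)

# With `use_beta_sigmas=True`, the `elif self.config.use_beta_sigmas:` branch
# added above computes sigmas via `_convert_to_beta` and maps them back to
# timesteps with `_sigma_to_t`.
print(scheduler.sigmas)
print(scheduler.timesteps)
```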