Commit 02ad952: Separate Sigma Schedule

1 parent 6131a93 commit 02ad952

6 files changed (+225, -115 lines)

6 files changed

+225
-115
lines changed

src/diffusers/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,7 @@
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
     _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
+    _import_structure["sigmas"] = ["BetaSigmas", "ExponentialSigmas", "KarrasSigmas"]

 try:
     if not is_flax_available():
@@ -174,6 +175,7 @@
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
     from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
+    from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas

     try:
         if not is_flax_available():
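
With both the lazy `_import_structure` entry and the direct import under `TYPE_CHECKING` in place, the new classes are exposed from the schedulers package. A minimal import sketch (assuming no additional re-export at the top-level `diffusers` namespace):

# The three sigma-schedule classes resolve lazily on first access.
from diffusers.schedulers import BetaSigmas, ExponentialSigmas, KarrasSigmas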

src/diffusers/schedulers/scheduling_heun_discrete.py

Lines changed: 10 additions & 115 deletions
@@ -22,10 +22,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, is_scipy_available
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-
-
-if is_scipy_available():
-    import scipy.stats
+from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas


 @dataclass
@@ -119,21 +116,14 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Clip the predicted sample for numerical stability.
         clip_sample_range (`float`, defaults to 1.0):
             The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
-        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
-            the sigmas are determined according to a sequence of noise levels {σi}.
-        use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
-        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
-            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
-            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
             An offset added to the inference steps, as required by some model families.
     """

+    ignore_for_config = ["sigma_schedule"]
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

@@ -146,20 +136,14 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
-        use_karras_sigmas: Optional[bool] = False,
-        use_exponential_sigmas: Optional[bool] = False,
-        use_beta_sigmas: Optional[bool] = False,
+        sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None,
         clip_sample: Optional[bool] = False,
         clip_sample_range: float = 1.0,
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
     ):
-        if self.config.use_beta_sigmas and not is_scipy_available():
+        if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
-        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError(
-                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
-            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -178,9 +162,10 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

+        self.sigma_schedule = sigma_schedule
+
         # set all values
         self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
-        self.use_karras_sigmas = use_karras_sigmas

         self._step_index = None
         self._begin_index = None
@@ -287,12 +272,8 @@ def set_timesteps(
             raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
             raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
-        if timesteps is not None and self.config.use_karras_sigmas:
-            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
-        if timesteps is not None and self.config.use_exponential_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if timesteps is not None and self.config.use_beta_sigmas:
-            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
+        if timesteps is not None and self.sigma_schedule is not None:
+            raise ValueError("Cannot use `timesteps` with `sigma_schedule`")

         num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
@@ -325,14 +306,8 @@ def set_timesteps(
         log_sigmas = np.log(sigmas)
         sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

-        if self.config.use_karras_sigmas:
-            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_exponential_sigmas:
-            sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
-            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
-        elif self.config.use_beta_sigmas:
-            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+        if self.sigma_schedule is not None:
+            sigmas = self.sigma_schedule(sigmas)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

         sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -376,86 +351,6 @@ def _sigma_to_t(self, sigma, log_sigmas):
         t = t.reshape(sigma.shape)
         return t

-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        rho = 7.0  # 7.0 is the value used in the paper
-        ramp = np.linspace(0, 1, num_inference_steps)
-        min_inv_rho = sigma_min ** (1 / rho)
-        max_inv_rho = sigma_max ** (1 / rho)
-        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
-    def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
-        """Constructs an exponential noise schedule."""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
-        return sigmas
-
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
-    def _convert_to_beta(
-        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
-    ) -> torch.Tensor:
-        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
-
-        # Hack to make sure that other schedulers which copy this function don't break
-        # TODO: Add this logic to the other schedulers
-        if hasattr(self.config, "sigma_min"):
-            sigma_min = self.config.sigma_min
-        else:
-            sigma_min = None
-
-        if hasattr(self.config, "sigma_max"):
-            sigma_max = self.config.sigma_max
-        else:
-            sigma_max = None
-
-        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
-        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
-
-        sigmas = np.array(
-            [
-                sigma_min + (ppf * (sigma_max - sigma_min))
-                for ppf in [
-                    scipy.stats.beta.ppf(timestep, alpha, beta)
-                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
-                ]
-            ]
-        )
-        return sigmas
-
     @property
     def state_in_first_order(self):
         return self.dt is None
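
Taken together, `HeunDiscreteScheduler` no longer accepts `use_karras_sigmas`, `use_exponential_sigmas`, or `use_beta_sigmas`; a single `sigma_schedule` object is stored on the scheduler and, when present, called on the default sigmas inside `set_timesteps`. A minimal usage sketch of the new API; calling `KarrasSigmas()` with no arguments is an assumption here, mirroring the `ExponentialSigmas` constructor shown further below:

from diffusers.schedulers import HeunDiscreteScheduler, KarrasSigmas

# Previously: HeunDiscreteScheduler(use_karras_sigmas=True)
scheduler = HeunDiscreteScheduler(
    num_train_timesteps=1000,
    beta_schedule="linear",
    sigma_schedule=KarrasSigmas(),  # assumed default constructor, like ExponentialSigmas()
)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.sigmas)  # Karras-spaced sigmas, ending in the appended 0.0

Note the new `ignore_for_config = ["sigma_schedule"]`, which keeps the schedule object itself out of the registered scheduler config.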
src/diffusers/schedulers/sigmas/__init__.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+from ...utils import (
+    DIFFUSERS_SLOW_IMPORT,
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    get_objects_from_module,
+    is_torch_available,
+    is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_pt_objects  # noqa F403
+
+    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
+else:
+    _import_structure["beta_sigmas"] = ["BetaSigmas"]
+    _import_structure["exponential_sigmas"] = ["ExponentialSigmas"]
+    _import_structure["karras_sigmas"] = ["KarrasSigmas"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+    try:
+        if not is_torch_available():
+            raise OptionalDependencyNotAvailable()
+
+    except OptionalDependencyNotAvailable:
+        from ...utils.dummy_pt_objects import *  # noqa F403
+    else:
+        from .beta_sigmas import BetaSigmas
+        from .exponential_sigmas import ExponentialSigmas
+        from .karras_sigmas import KarrasSigmas
+
+
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
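
The module replaces itself in `sys.modules` with a `_LazyModule`, so `beta_sigmas`, `exponential_sigmas`, and `karras_sigmas` are only imported when one of the exported names is first accessed. The following is not diffusers' `_LazyModule`; it is just a rough sketch of the same idea using a PEP 562 module-level `__getattr__`:

import importlib

# Hypothetical stand-in: map each exported name to the submodule defining it,
# and import that submodule only when the name is first requested.
_import_structure = {
    "beta_sigmas": ["BetaSigmas"],
    "exponential_sigmas": ["ExponentialSigmas"],
    "karras_sigmas": ["KarrasSigmas"],
}
_name_to_module = {name: module for module, names in _import_structure.items() for name in names}


def __getattr__(name):
    # Called only when `name` is not found through the normal module lookup.
    if name in _name_to_module:
        submodule = importlib.import_module(f".{_name_to_module[name]}", __package__)
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")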
src/diffusers/schedulers/sigmas/beta_sigmas.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import numpy as np
+import torch
+
+from ...utils import is_scipy_available
+
+
+if is_scipy_available():
+    import scipy.stats
+
+
+class BetaSigmas:
+    def __init__(
+        self,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+        alpha: float = 0.6,
+        beta: float = 0.6,
+    ):
+        if not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+        self.alpha = alpha
+        self.beta = beta
+
+    def __call__(self, in_sigmas: torch.Tensor):
+        sigma_min = self.sigma_min
+        if sigma_min is None:
+            sigma_min = in_sigmas[-1].item()
+        sigma_max = self.sigma_max
+        if sigma_max is None:
+            sigma_max = in_sigmas[0].item()
+
+        num_inference_steps = len(in_sigmas)
+
+        alpha = self.alpha
+        beta = self.beta
+
+        sigmas = np.array(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
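
`BetaSigmas` is a plain callable: given the scheduler's default sigmas, it redistributes `len(in_sigmas)` values between the inferred `sigma_max` and `sigma_min` using the Beta(alpha, beta) quantile function (`scipy.stats.beta.ppf`), per "Beta Sampling is All You Need" (https://huggingface.co/papers/2407.12173). A quick standalone check (requires scipy; the input tensor is just an illustrative descending schedule):

import torch
from diffusers.schedulers import BetaSigmas

default_sigmas = torch.linspace(10.0, 0.1, 30)   # any descending schedule works as input
sigmas = BetaSigmas(alpha=0.6, beta=0.6)(default_sigmas)

# Output is a numpy array of the same length; the endpoints match the inferred
# sigma_max (10.0) and sigma_min (0.1), since ppf(1) == 1 and ppf(0) == 0.
print(sigmas[0], sigmas[-1])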
src/diffusers/schedulers/sigmas/exponential_sigmas.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import numpy as np
+import torch
+
+
+class ExponentialSigmas:
+    def __init__(
+        self,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ):
+        self.sigma_min = sigma_min
+        self.sigma_max = sigma_max
+
+    def __call__(self, in_sigmas: torch.Tensor):
+        sigma_min = self.sigma_min
+        if sigma_min is None:
+            sigma_min = in_sigmas[-1].item()
+        sigma_max = self.sigma_max
+        if sigma_max is None:
+            sigma_max = in_sigmas[0].item()
+
+        num_inference_steps = len(in_sigmas)
+
+        sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+        return sigmas
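
`ExponentialSigmas` follows the same callable pattern but spaces the sigmas log-linearly, i.e. uniformly in `log(sigma)` between `sigma_max` and `sigma_min`, so the ratio between consecutive sigmas is constant. A small numeric sketch (assuming the class is importable via the registration in `schedulers/__init__.py`):

import torch
from diffusers.schedulers import ExponentialSigmas

default_sigmas = torch.linspace(10.0, 0.1, 5)    # illustrative descending input
sigmas = ExponentialSigmas()(default_sigmas)

# Uniform in log-space: roughly [10.0, 3.162, 1.0, 0.3162, 0.1],
# with a constant neighbour ratio of (0.1 / 10.0) ** (1 / 4).
print(sigmas)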
