Commit 0f96265
base class noise schedule
1 parent 3ee8582

2 files changed: +164 -156 lines
bayesflow/experimental/diffusion_model/noise_schedule_base.py

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
from abc import ABC, abstractmethod
from typing import Union, Literal

from keras import ops

from bayesflow.types import Tensor
from bayesflow.utils.serialization import deserialize, serializable


# disable module check; set the proper module after moving out of experimental
@serializable("bayesflow.networks", disable_module_check=True)
class NoiseSchedule(ABC):
    r"""Noise schedule for diffusion models. We follow the notation from [1].

    The diffusion process is defined by a noise schedule, which determines how the noise level changes over time.
    We define the noise schedule as a function of the log signal-to-noise ratio (lambda), which can be used
    interchangeably with the diffusion time (t).

    The noise process is defined as: z = alpha(t) * x + sigma(t) * e, where e ~ N(0, I).
    The schedule is defined as: \lambda(t) = \log \alpha^2(t) - \log \sigma^2(t).

    We can also define a weighting function for each noise level for the loss function. Often the noise schedule is
    the same for the forward and reverse process, but this is not necessary and can be changed via the training flag.

    [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
    Augmentation: Kingma et al. (2023)
    """

    def __init__(
        self,
        name: str,
        variance_type: Literal["preserving", "exploding"],
        weighting: Literal["sigmoid", "likelihood_weighting"] = None,
    ):
        """
        Initialize the noise schedule.

        Parameters
        ----------
        name : str
            The name of the noise schedule.
        variance_type : Literal["preserving", "exploding"]
            If the variance of the noise added to the data should be preserved over time, use "preserving".
            If it should increase over time, use "exploding".
        weighting : Literal["sigmoid", "likelihood_weighting"], optional
            The type of weighting function to use for the noise schedule.
            Default is None, which means no weighting is applied.
        """
        self.name = name
        self._variance_type = variance_type
        self.log_snr_min = None  # should be set in the subclasses
        self.log_snr_max = None  # should be set in the subclasses
        self._weighting = weighting

    @abstractmethod
    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
        pass

    @abstractmethod
    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
        """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
        pass

    @abstractmethod
    def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
        r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
        pass

    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
        r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
        It can be derived from the derivative of the schedule:

        .. math::
            \beta(t) = d/dt \log(1 + e^{-snr(t)})

            f(z, t) = -0.5 * \beta(t) * z

            g(t)^2 = \beta(t)

        The corresponding differential equations are::

            SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
            ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt

        For a variance exploding schedule, one should set f(z, t) = 0.
        """
        beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
        if x is None:  # return g^2 only
            return beta
        if self._variance_type == "preserving":
            f = -0.5 * beta * x
        elif self._variance_type == "exploding":
            f = ops.zeros_like(beta)
        else:
            raise ValueError(f"Unknown variance type: {self._variance_type}")
        return f, beta

    def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
        """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

        Default is a variance preserving schedule::

            alpha(t) = sqrt(sigmoid(log_snr_t))
            sigma(t) = sqrt(sigmoid(-log_snr_t))

        For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda).
        """
        if self._variance_type == "preserving":
            # variance preserving schedule
            alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
            sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
        elif self._variance_type == "exploding":
            # variance exploding schedule
            alpha_t = ops.ones_like(log_snr_t)
            sigma_t = ops.sqrt(ops.exp(-log_snr_t))
        else:
            raise ValueError(f"Unknown variance type: {self._variance_type}")
        return alpha_t, sigma_t

    def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
        """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
        Default weighting is None, which means only ones are returned.
        Generally, weighting functions should be defined for a noise prediction loss.
        """
        if self._weighting is None:
            return ops.ones_like(log_snr_t)
        elif self._weighting == "sigmoid":
            # sigmoid weighting based on Kingma et al. (2023)
            return ops.sigmoid(-log_snr_t + 2)
        elif self._weighting == "likelihood_weighting":
            # likelihood weighting based on Song et al. (2021)
            g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
            sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
            return g_squared / ops.square(sigma_t)
        else:
            raise ValueError(f"Unknown weighting type: {self._weighting}")

    def get_config(self):
        return dict(name=self.name, variance_type=self._variance_type)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**deserialize(config, custom_objects=custom_objects))

    def validate(self):
        """Validate the noise schedule."""
        if self.log_snr_min >= self.log_snr_max:
            raise ValueError("log_snr_min must be less than log_snr_max.")
        for training in [True, False]:
            if not ops.isfinite(self.get_log_snr(0.0, training=training)):
                raise ValueError(f"log_snr(0) must be finite with training={training}.")
            if not ops.isfinite(self.get_log_snr(1.0, training=training)):
                raise ValueError(f"log_snr(1) must be finite with training={training}.")
            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
                raise ValueError(f"t(0) must be finite with training={training}.")
            if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
                raise ValueError(f"t(1) must be finite with training={training}.")
        if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
            raise ValueError("d/dt log_snr at t=0 must be finite.")
        if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
            raise ValueError("d/dt log_snr at t=1 must be finite.")

bayesflow/experimental/diffusion_model/noise_schedules.py

Lines changed: 1 addition & 156 deletions
@@ -1,167 +1,12 @@
 import math
-from abc import ABC, abstractmethod
 from typing import Union, Literal

 from keras import ops

 from bayesflow.types import Tensor
 from bayesflow.utils.serialization import deserialize, serializable

-# disable module check, use potential module after moving from experimental
-@serializable("bayesflow.networks", disable_module_check=True)
-class NoiseSchedule(ABC):
-    ...  # the remaining removed lines are the NoiseSchedule class body, moved to noise_schedule_base.py (shown above)
+from .noise_schedule_base import NoiseSchedule


 @serializable("bayesflow.experimental")
