Skip to content

Commit 739491a

Browse files
committed
improve schedules
1 parent f2d7de4 commit 739491a

File tree

1 file changed

+37
-36
lines changed

1 file changed

+37
-36
lines changed

bayesflow/experimental/diffusion_model.py

Lines changed: 37 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,22 @@ class NoiseSchedule(ABC):
3737
Augmentation: Kingma et al. (2023)
3838
"""
3939

40-
def __init__(self, name: str):
40+
def __init__(self, name: str, variance_type: str):
4141
self.name = name
42-
43-
# for variance preserving schedules
44-
self.scale_base_distribution = 1.0
42+
self.variance_type = variance_type # 'exploding' or 'preserving'
43+
self._log_snr_min = ops.convert_to_tensor(-15) # should be set in the subclasses
44+
self._log_snr_max = ops.convert_to_tensor(15) # should be set in the subclasses
45+
46+
@property
47+
def scale_base_distribution(self):
48+
"""Get the scale of the base distribution."""
49+
if self.variance_type == "preserving":
50+
return 1.0
51+
elif self.variance_type == "exploding":
52+
# e.g., EDM is a variance exploding schedule
53+
return ops.exp(-self._log_snr_min)
54+
else:
55+
raise ValueError(f"Unknown variance type: {self.variance_type}")
4556

4657
@abstractmethod
4758
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
@@ -74,17 +85,32 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
7485
beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
7586
if x is None: # return g only
7687
return ops.sqrt(beta)
77-
f = -0.5 * beta * x
88+
if self.variance_type == "preserving":
89+
f = -0.5 * beta * x
90+
elif self.variance_type == "exploding":
91+
f = ops.zeros_like(beta)
92+
else:
93+
raise ValueError(f"Unknown variance type: {self.variance_type}")
7894
return f, ops.sqrt(beta)
7995

8096
def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
8197
"""Get alpha and sigma for a given log signal-to-noise ratio (lambda).
8298
83-
Default is a variance preserving schedule.
99+
Default is a variance preserving schedule:
100+
alpha(t) = sqrt(sigmoid(log_snr_t))
101+
sigma(t) = sqrt(sigmoid(-log_snr_t))
84102
For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
85103
"""
86-
alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
87-
sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
104+
if self.variance_type == "preserving":
105+
# variance preserving schedule
106+
alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
107+
sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
108+
elif self.variance_type == "exploding":
109+
# variance exploding schedule
110+
alpha_t = ops.ones_like(log_snr_t)
111+
sigma_t = ops.sqrt(ops.exp(-log_snr_t))
112+
else:
113+
raise ValueError(f"Unknown variance type: {self.variance_type}")
88114
return alpha_t, sigma_t
89115

90116
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
@@ -106,7 +132,7 @@ class LinearNoiseSchedule(NoiseSchedule):
106132
"""
107133

108134
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15):
109-
super().__init__(name="linear_noise_schedule")
135+
super().__init__(name="linear_noise_schedule", variance_type="preserving")
110136
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
111137
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
112138

@@ -155,7 +181,7 @@ class CosineNoiseSchedule(NoiseSchedule):
155181
"""
156182

157183
def __init__(self, min_log_snr: float = -15, max_log_snr: float = 15, s_shift_cosine: float = 0.0):
158-
super().__init__(name="cosine_noise_schedule")
184+
super().__init__(name="cosine_noise_schedule", variance_type="preserving")
159185
self._log_snr_min = ops.convert_to_tensor(min_log_snr)
160186
self._log_snr_max = ops.convert_to_tensor(max_log_snr)
161187
self._s_shift_cosine = ops.convert_to_tensor(s_shift_cosine)
@@ -202,7 +228,7 @@ class EDMNoiseSchedule(NoiseSchedule):
202228
"""
203229

204230
def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max: float = 80):
205-
super().__init__(name="edm_noise_schedule")
231+
super().__init__(name="edm_noise_schedule", variance_type="exploding")
206232
self.sigma_data = ops.convert_to_tensor(sigma_data)
207233
self.sigma_max = ops.convert_to_tensor(sigma_max)
208234
self.sigma_min = ops.convert_to_tensor(sigma_min)
@@ -216,9 +242,6 @@ def __init__(self, sigma_data: float = 0.5, sigma_min: float = 0.002, sigma_max:
216242
self._t_min = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
217243
self._t_max = self.get_t_from_log_snr(log_snr_t=self._log_snr_max, training=True)
218244

219-
# EDM is a variance exploding schedule
220-
self.scale_base_distribution = ops.exp(-self._log_snr_min)
221-
222245
def get_log_snr(self, t: Tensor, training: bool) -> Tensor:
223246
"""Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
224247
t_trunc = self._t_min + (self._t_max - self._t_min) * t
@@ -278,28 +301,6 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
278301
factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
279302
return -factor * dsnr_dt
280303

281-
def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = True) -> tuple[Tensor, Tensor]:
282-
"""Compute the drift and optionally the diffusion term for the variance exploding reverse SDE.
283-
\beta(t) = d/dt log(1 + e^(-snr(t)))
284-
f(z, t) = 0
285-
g(t)^2 = \beta(t)
286-
287-
SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
288-
ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt
289-
"""
290-
# Default implementation is to return the diffusion term only
291-
beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training)
292-
if x is None: # return g only
293-
return ops.sqrt(beta)
294-
f = ops.zeros_like(beta) # variance exploding schedule
295-
return f, ops.sqrt(beta)
296-
297-
def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
298-
"""Get alpha and sigma for a given log signal-to-noise ratio (lambda) for a variance exploding schedule."""
299-
alpha_t = ops.ones_like(log_snr_t)
300-
sigma_t = ops.sqrt(ops.exp(-log_snr_t))
301-
return alpha_t, sigma_t
302-
303304
def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
304305
"""Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
305306
return ops.exp(-log_snr_t) + 0.5**2

0 commit comments

Comments
 (0)