
Commit 630a823

adding more noise schedules
1 parent 549a055 commit 630a823


bayesflow/experimental/diffusion_model.py

Lines changed: 217 additions & 54 deletions
@@ -6,6 +6,7 @@
 from bayesflow.types import Tensor, Shape
 import bayesflow as bf
 from bayesflow.networks import InferenceNetwork
+import math
 
 from bayesflow.utils import (
     expand_right_as,
@@ -21,9 +22,13 @@
 
 @serializable(package="bayesflow.networks")
 class DiffusionModel(InferenceNetwork):
-    """Diffusion Model as described as Elucidated Diffusion Model in [1].
+    """Diffusion Model as described in this overview paper [1].
+
+    [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
+        Augmentation: Kingma et al. (2023)
+    [2] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021)
+    [3] Elucidating the Design Space of Diffusion-Based Generative Models: arXiv:2206.00364
 
-    [1] Elucidating the Design Space of Diffusion-Based Generative Models: arXiv:2206.00364
     """
 
     MLP_DEFAULT_CONFIG = {
@@ -74,16 +79,34 @@ def __init__(
 
         super().__init__(base_distribution=None, **keras_kwargs(kwargs))
 
+        # todo: clean up these configurations
+        # EDM hyper-parameters
         # internal tunable parameters not intended to be modified by the average user
         self.max_sigma = kwargs.get("max_sigma", 80.0)
         self.min_sigma = kwargs.get("min_sigma", 1e-4)
         self.rho = kwargs.get("rho", 7)
         # hyper-parameters for sampling the noise level
         self.p_mean = kwargs.get("p_mean", -1.2)
         self.p_std = kwargs.get("p_std", 1.2)
+        self._noise_schedule = kwargs.get("noise_schedule", "EDM")
+
+        # general hyper-parameters
+        self._train_time = kwargs.get("train_time", "continuous")
+        self._timesteps = kwargs.get("timesteps", None)
+        if self._train_time == "discrete":
+            if not isinstance(self._timesteps, int):
+                raise ValueError('timesteps must be defined if "discrete" training time is set')
+        self._loss_type = kwargs.get("loss_type", "eps")
+        self._weighting_function = kwargs.get("weighting_function", None)
+        self._log_snr_min = kwargs.get("log_snr_min", -15)
+        self._log_snr_max = kwargs.get("log_snr_max", 15)
+        self._t_min = self._get_t_from_log_snr(log_snr_t=self._log_snr_max)
+        self._t_max = self._get_t_from_log_snr(log_snr_t=self._log_snr_min)
+        self._s_shift_cosine = kwargs.get("s_shift_cosine", 0.0)
 
         # latent distribution (not configurable)
         self.base_distribution = bf.distributions.DiagonalNormal(mean=0.0, std=self.max_sigma)
+
         self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})
 
         self.sigma_data = sigma_data
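
For orientation, a minimal usage sketch of the new options introduced above; it assumes the constructor keeps forwarding extra keyword arguments through **kwargs and that the module exposes DiffusionModel at the file path shown in this diff (both are assumptions, not confirmed by the commit).

    # hypothetical usage; import path and kwargs-forwarding are assumptions
    from bayesflow.experimental.diffusion_model import DiffusionModel

    model = DiffusionModel(
        noise_schedule="cosine",       # "EDM" (default), "linear", "cosine", or "flow_matching"
        loss_type="v",                 # "eps" (default), "v", or "EDM"
        weighting_function="sigmoid",  # None, "likelihood_weighting", "sigmoid", or "min-snr"
        train_time="continuous",       # or "discrete", together with timesteps=<int>
        log_snr_min=-15,
        log_snr_max=15,
    )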
@@ -142,51 +165,62 @@ def _c_in_fn(self, sigma):
         return 1.0 / ops.sqrt(sigma**2 + self.sigma_data**2)
 
     def _c_noise_fn(self, sigma):
-        return 0.25 * ops.log(sigma)
-
-    def _denoiser_fn(
-        self,
-        xz: Tensor,
-        sigma: Tensor,
-        conditions: Tensor = None,
-        training: bool = False,
-    ):
-        # calculate output of the network
-        c_in = self._c_in_fn(sigma)
-        c_noise = self._c_noise_fn(sigma)
-        xz_pre = c_in * xz
-        if conditions is None:
-            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
-        else:
-            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)
-        out = self.output_projector(self.subnet(xtc, training=training), training=training)
-        return self._c_skip_fn(sigma) * xz + self._c_out_fn(sigma) * out
+        return 0.25 * ops.log(sigma)  # this is the log-snr times a constant
 
     def velocity(
         self,
         xz: Tensor,
-        sigma: float | Tensor,
+        time: float | Tensor,
         conditions: Tensor = None,
         training: bool = False,
+        clip_x: bool = True,
     ) -> Tensor:
-        # transform sigma vector into correct shape
-        sigma = keras.ops.convert_to_tensor(sigma, dtype=keras.ops.dtype(xz))
-        sigma = expand_right_as(sigma, xz)
-        sigma = keras.ops.broadcast_to(sigma, keras.ops.shape(xz)[:-1] + (1,))
+        # calculate the current noise level and transform into correct shape
+        log_snr_t = expand_right_as(self._get_log_snr(t=time), xz)
+        alpha_t, sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)
 
-        d = self._denoiser_fn(xz, sigma, conditions, training=training)
-        return (xz - d) / sigma
+        if self._noise_schedule == "EDM":
+            # scale the input
+            xz = alpha_t * xz
+
+        if conditions is None:
+            xtc = keras.ops.concatenate([xz, log_snr_t], axis=-1)
+        else:
+            xtc = keras.ops.concatenate([xz, log_snr_t, conditions], axis=-1)
+        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+
+        if self._noise_schedule == "EDM":
+            # scale the output
+            s = ops.exp(-1 / 2 * log_snr_t)
+            pred_scaled = self._c_skip_fn(s) * xz + self._c_out_fn(s) * pred
+            out = (xz - pred_scaled) / s
+        else:
+            # first convert the prediction to an x-prediction
+            if self._loss_type == "eps":
+                x_pred = (xz - sigma_t * pred) / alpha_t
+            else:  # self._loss_type == "v"
+                x_pred = alpha_t * xz - sigma_t * pred
+
+            # clip x if necessary
+            if clip_x:
+                x_pred = ops.clip(x_pred, -5, 5)
+            # convert x to score
+            score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
+            # compute velocity for the ODE depending on the noise schedule
+            f, g = self._get_drift_diffusion(log_snr_t=log_snr_t, x=xz)
+            out = f - 0.5 * ops.square(g) * score
+        return out
 
     def _velocity_trace(
         self,
         xz: Tensor,
-        sigma: Tensor,
+        time: Tensor,
         conditions: Tensor = None,
         max_steps: int = None,
         training: bool = False,
     ) -> (Tensor, Tensor):
         def f(x):
-            return self.velocity(x, sigma=sigma, conditions=conditions, training=training)
+            return self.velocity(x, time=time, conditions=conditions, training=training)
 
         v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)
 
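To spell out the conversion chain used in the new velocity() for the non-EDM branch, here is a self-contained sketch; the function name and its inputs (alpha_t, sigma_t, f, g coming from the schedule helpers) are illustrative only.

    from keras import ops

    def probability_flow_velocity(xz, eps_pred, alpha_t, sigma_t, f, g):
        # eps-prediction -> x-prediction (the "eps" branch above)
        x_pred = (xz - sigma_t * eps_pred) / alpha_t
        # x-prediction -> score; algebraically this equals -eps_pred / sigma_t
        score = (alpha_t * x_pred - xz) / ops.square(sigma_t)
        # probability-flow ODE drift: f(x, t) - 0.5 * g(t)^2 * score
        return f - 0.5 * ops.square(g) * score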

@@ -207,7 +241,7 @@ def _forward(
         if density:
 
             def deltas(time, xz):
-                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}
 
             state = {
@@ -226,7 +260,7 @@ def deltas(time, xz):
             return z, log_density
 
         def deltas(time, xz):
-            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
 
         state = {"xz": x}
         state = integrate(
@@ -256,7 +290,7 @@ def _inverse(
         if density:
 
             def deltas(time, xz):
-                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
+                v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}
 
             state = {
@@ -271,7 +305,7 @@ def deltas(time, xz):
             return x, log_density
 
         def deltas(time, xz):
-            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}
+            return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}
 
         state = {"xz": z}
         state = integrate(
@@ -284,6 +318,120 @@ def deltas(time, xz):
 
         return x
 
+    def _get_drift_diffusion(self, log_snr_t, x=None):  # t is not truncated
+        """
+        Compute the drift and diffusion terms of the SDE, based on
+        beta(t) = d/dt log(1 + e^(-snr(t))) for the truncated schedules.
+        """
+        t = self._get_t_from_log_snr(log_snr_t=log_snr_t)
+        # Compute the truncated time t_trunc
+        t_trunc = self._t_min + (self._t_max - self._t_min) * t
+
+        # Compute d/dx snr(x) based on the noise schedule
+        if self._noise_schedule == "linear":
+            # d/dx snr(x) = - 2*x*exp(x^2) / (exp(x^2) - 1)
+            dsnr_dx = -(2 * t_trunc * ops.exp(t_trunc**2)) / (ops.exp(t_trunc**2) - 1)
+        elif self._noise_schedule == "cosine":
+            # d/dx snr(x) = -2*pi/sin(pi*x)
+            dsnr_dx = -(2 * math.pi) / ops.sin(math.pi * t_trunc)
+        elif self._noise_schedule == "flow_matching":
+            # d/dx snr(x) = -2/(x*(1-x))
+            dsnr_dx = -2 / (t_trunc * (1 - t_trunc))
+        else:
+            raise ValueError("Invalid 'noise_schedule'.")
+
+        # Chain rule: d/dt snr(t) = d/dx snr(x) * (t_max - t_min)
+        dsnr_dt = dsnr_dx * (self._t_max - self._t_min)
+
+        # Using the chain rule on f(t) = log(1 + e^(-snr(t))):
+        # f'(t) = - (e^{-snr(t)} / (1 + e^{-snr(t)})) * dsnr_dt
+        factor = ops.exp(-log_snr_t) / (1 + ops.exp(-log_snr_t))
+
+        beta_t = -factor * dsnr_dt
+        g = ops.sqrt(beta_t)  # diffusion term
+        if x is None:
+            return g
+        f = -0.5 * beta_t * x  # drift term
+        return f, g
+
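As a sanity check of the beta(t) expression above, a small standalone sketch (illustrative only; the truncation to [t_min, t_max] is omitted) compares the analytic derivative with a finite-difference estimate for the cosine schedule:

    import math

    def log_snr_cosine(t):
        return -2.0 * math.log(math.tan(math.pi * t / 2.0))

    def beta_analytic(t):
        # beta(t) = -sigmoid(-snr(t)) * d/dt snr(t), with d/dt snr(t) = -2*pi / sin(pi*t)
        dsnr_dt = -(2.0 * math.pi) / math.sin(math.pi * t)
        snr = log_snr_cosine(t)
        return -math.exp(-snr) / (1.0 + math.exp(-snr)) * dsnr_dt

    def beta_numeric(t, h=1e-5):
        f = lambda u: math.log(1.0 + math.exp(-log_snr_cosine(u)))
        return (f(t + h) - f(t - h)) / (2.0 * h)

    t = 0.37
    print(beta_analytic(t), beta_numeric(t))  # the two values should agree closely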
+    def _get_log_snr(self, t: Tensor) -> Tensor:
+        """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
+        if self._noise_schedule == "EDM":
+            # EDM defines tilde sigma ~ N(p_mean, p_std^2)
+            # tilde sigma^2 = exp(-lambda), hence lambda = -2 * log(sigma)
+            # sample noise
+            log_sigma_tilde = self.p_mean + self.p_std * keras.random.normal(
+                ops.shape(t), dtype=ops.dtype(t), seed=self.seed_generator
+            )
+            # calculate the log signal-to-noise ratio
+            log_snr_t = -2 * log_sigma_tilde
+            return log_snr_t
+
+        t_trunc = self._t_min + (self._t_max - self._t_min) * t
+        if self._noise_schedule == "linear":
+            log_snr_t = -ops.log(ops.exp(ops.square(t_trunc)) - 1)
+        elif self._noise_schedule == "cosine":  # this is usually used with variance_preserving
+            log_snr_t = -2 * ops.log(ops.tan(math.pi * t_trunc / 2)) + 2 * self._s_shift_cosine
+        elif self._noise_schedule == "flow_matching":  # this is usually used with sub_variance_preserving
+            log_snr_t = 2 * ops.log((1 - t_trunc) / t_trunc)
+        else:
+            raise ValueError("Unknown noise schedule: {}".format(self._noise_schedule))
+        return log_snr_t
+
+    def _get_t_from_log_snr(self, log_snr_t) -> Tensor:
+        # Invert the noise schedule to recover t (not truncated)
+        if self._noise_schedule == "linear":
+            # SNR = -log(exp(t^2) - 1)
+            # => t = sqrt(log(1 + exp(-snr)))
+            t = ops.sqrt(ops.log(1 + ops.exp(-log_snr_t)))
+        elif self._noise_schedule == "cosine":
+            # SNR = -2 * log(tan(pi*t/2))
+            # => t = 2/pi * arctan(exp(-snr/2))
+            t = 2 / math.pi * ops.arctan(ops.exp((2 * self._s_shift_cosine - log_snr_t) / 2))
+        elif self._noise_schedule == "flow_matching":
+            # SNR = 2 * log((1-t)/t)
+            # => t = 1 / (1 + exp(snr/2))
+            t = 1 / (1 + ops.exp(log_snr_t / 2))
+        elif self._noise_schedule == "EDM":
+            raise NotImplementedError
+        else:
+            raise ValueError("Unknown noise schedule: {}".format(self._noise_schedule))
+        return t
+
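A quick round-trip check of the inversion above (illustrative only; truncation and the cosine shift are omitted), here for the flow-matching schedule:

    import math

    def log_snr_flow_matching(t):
        return 2.0 * math.log((1.0 - t) / t)

    def t_from_log_snr_flow_matching(log_snr):
        return 1.0 / (1.0 + math.exp(log_snr / 2.0))

    for t in (0.1, 0.5, 0.9):
        # should print matching pairs, e.g. 0.1 0.1
        print(t, t_from_log_snr_flow_matching(log_snr_flow_matching(t)))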
+    def _get_alpha_sigma(self, log_snr_t: Tensor) -> tuple[Tensor, Tensor]:
+        if self._noise_schedule == "EDM":
+            # EDM: noisy_x = c_in * (x + s * e) = c_in * x + c_in * s * e
+            # s^2 = exp(-lambda)
+            s = ops.exp(-1 / 2 * log_snr_t)
+            c_in = self._c_in_fn(s)
+
+            # alpha = c_in(s), sigma = c_in * s
+            alpha_t = c_in
+            sigma_t = c_in * s
+        else:
+            # variance preserving noise schedules
+            alpha_t = keras.ops.sqrt(keras.ops.sigmoid(log_snr_t))
+            sigma_t = keras.ops.sqrt(keras.ops.sigmoid(-log_snr_t))
+        return alpha_t, sigma_t
+
def _get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
417+
if self._noise_schedule == "EDM":
418+
# EDM: weights are constructed elsewhere
419+
weights = ops.ones_like(log_snr_t)
420+
return weights
421+
422+
if self._weighting_function == "likelihood_weighting": # based on Song et al. (2021)
423+
g_t = self._get_drift_diffusion(log_snr_t=log_snr_t)
424+
sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)[1]
425+
weights = ops.square(g_t / sigma_t)
426+
elif self._weighting_function == "sigmoid": # based on Kingma et al. (2023)
427+
weights = ops.sigmoid(-log_snr_t / 2)
428+
elif self._weighting_function == "min-snr": # based on Hang et al. (2023)
429+
gamma = 5
430+
weights = 1 / ops.cosh(log_snr_t / 2) * ops.minimum(ops.ones_like(log_snr_t), gamma * ops.exp(-log_snr_t))
431+
else:
432+
weights = ops.ones_like(log_snr_t)
433+
return weights
434+
287435
def compute_metrics(
288436
self,
289437
x: Tensor | Sequence[Tensor, ...],
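To give a feel for the new weighting options, an illustrative comparison (not part of the class) of the "sigmoid" and "min-snr" weights at a few log-SNR values, with gamma = 5 as in the code above:

    import math

    for log_snr in (-8.0, 0.0, 8.0):
        w_sigmoid = 1.0 / (1.0 + math.exp(log_snr / 2.0))  # sigmoid(-log_snr / 2)
        w_min_snr = (1.0 / math.cosh(log_snr / 2.0)) * min(1.0, 5.0 * math.exp(-log_snr))
        print(log_snr, round(w_sigmoid, 4), round(w_min_snr, 4))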
@@ -297,36 +445,51 @@ def compute_metrics(
         conditions_shape = None if conditions is None else keras.ops.shape(conditions)
         self.build(xz_shape, conditions_shape)
 
-        # sample log-noise level
-        log_sigma = self.p_mean + self.p_std * keras.random.normal(
-            ops.shape(x)[:1], dtype=ops.dtype(x), seed=self.seed_generator
-        )
-        # noise level with shape (batch_size, 1)
-        sigma = ops.exp(log_sigma)[:, None]
+        # sample training diffusion time
+        if self._train_time == "continuous":
+            t = keras.random.uniform((keras.ops.shape(x)[0],))
+        elif self._train_time == "discrete":
+            i = keras.random.randint((keras.ops.shape(x)[0],), minval=0, maxval=self._timesteps)
+            t = keras.ops.cast(i, keras.ops.dtype(x)) / keras.ops.cast(self._timesteps, keras.ops.dtype(x))
+        else:
+            raise NotImplementedError(f"Training time {self._train_time} not implemented")
+
+        # calculate the noise level
+        log_snr_t = expand_right_as(self._get_log_snr(t), x)
+        alpha_t, sigma_t = self._get_alpha_sigma(log_snr_t=log_snr_t)
 
         # generate noise vector
-        z = sigma * keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)
+        eps_t = keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)
 
-        # calculate preconditioning
-        c_skip = self._c_skip_fn(sigma)
-        c_out = self._c_out_fn(sigma)
-        c_in = self._c_in_fn(sigma)
-        c_noise = self._c_noise_fn(sigma)
-        xz_pre = c_in * (x + z)
+        # diffuse x
+        diffused_x = alpha_t * x + sigma_t * eps_t
 
         # calculate output of the network
         if conditions is None:
-            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
+            xtc = keras.ops.concatenate([diffused_x, log_snr_t], axis=-1)
         else:
-            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)
+            xtc = keras.ops.concatenate([diffused_x, log_snr_t, conditions], axis=-1)
 
         out = self.output_projector(self.subnet(xtc, training=training), training=training)
 
-        # Calculate loss:
-        lam = 1 / c_out[:, 0] ** 2
-        effective_weight = lam * c_out[:, 0] ** 2
-        unweighted_loss = ops.mean((out - 1 / c_out * (x - c_skip * (x + z))) ** 2, axis=-1)
-        loss = effective_weight * unweighted_loss
+        # Calculate loss
+        weights_for_snr = self._get_weights_for_snr(log_snr_t=log_snr_t)
+        if self._loss_type == "eps":
+            loss = weights_for_snr * ops.mean((out - eps_t) ** 2, axis=-1)
+        elif self._loss_type == "v":
+            v_t = alpha_t * eps_t - sigma_t * x
+            loss = weights_for_snr * ops.mean((out - v_t) ** 2, axis=-1)
+        elif self._loss_type == "EDM":
+            s = ops.exp(-1 / 2 * log_snr_t)
+            c_skip = self._c_skip_fn(s)
+            c_out = self._c_out_fn(s)
+            lam = 1 / c_out[:, 0] ** 2
+            effective_weight = lam * c_out[:, 0] ** 2
+            unweighted_loss = ops.mean((out - 1 / c_out * (x - c_skip * (x + s * eps_t))) ** 2, axis=-1)
+            loss = effective_weight * unweighted_loss
+        else:
+            raise ValueError(f"Unknown loss type: {self._loss_type}")
+
         loss = weighted_mean(loss, sample_weight)
 
         base_metrics = super().compute_metrics(x, conditions, sample_weight, stage)
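
For reference, a minimal sketch of how the "eps" and "v" targets above relate to the diffused sample, assuming a variance-preserving schedule; all names and numbers are illustrative:

    from keras import ops, random

    x = random.normal((4, 2))                   # clean sample
    eps_t = random.normal((4, 2))               # noise
    alpha_t, sigma_t = 0.8, 0.6                 # any pair with alpha**2 + sigma**2 = 1
    diffused_x = alpha_t * x + sigma_t * eps_t  # network input (plus log-SNR and conditions)

    pred = random.normal((4, 2))                # stands in for the subnet output
    loss_eps = ops.mean((pred - eps_t) ** 2, axis=-1)                          # loss_type="eps"
    loss_v = ops.mean((pred - (alpha_t * eps_t - sigma_t * x)) ** 2, axis=-1)  # loss_type="v"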
