
Commit 0b5a800

docstring formatting
1 parent d43ea98 commit 0b5a800


2 files changed (+13, -4 lines)


bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ class DiffusionModel(InferenceNetwork):
     """Diffusion Model as described in this overview paper [1].

     [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
-    Augmentation: Kingma et al. (2023)
+        Augmentation: Kingma et al. (2023)
+
     [2] Score-Based Generative Modeling through Stochastic Differential Equations: Song et al. (2021)
     """

bayesflow/experimental/diffusion_model/noise_schedules.py

Lines changed: 11 additions & 3 deletions
@@ -24,7 +24,7 @@ class NoiseSchedule(ABC):
     the same for the forward and reverse process, but this is not necessary and can be changed via the training flag.

     [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
-    Augmentation: Kingma et al. (2023)
+        Augmentation: Kingma et al. (2023)
     """

     def __init__(
@@ -72,10 +72,16 @@ def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
     def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
         r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
         It can be derived from the derivative of the schedule:
-        \beta(t) = d/dt log(1 + e^(-snr(t)))
+
+        .. math::
+            \beta(t) = d/dt \log(1 + e^{-snr(t)})
+
         f(z, t) = -0.5 * \beta(t) * z
+
         g(t)^2 = \beta(t)

+        The corresponding differential equations are::
+
         SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
         ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt
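To make the relations in this docstring concrete, here is a minimal NumPy sketch, not BayesFlow's API: the linear log-SNR schedule lambda(t), its endpoint values, and the helper names are assumptions for illustration only. It shows how beta(t), the drift f(z, t), and the squared diffusion g(t)^2 follow from the log-SNR schedule.

    # Minimal sketch (assumptions: linear log-SNR schedule, NumPy instead of backend tensors).
    import numpy as np

    LOG_SNR_MAX, LOG_SNR_MIN = 10.0, -10.0  # assumed schedule endpoints

    def log_snr(t):
        # lambda(t): linear interpolation from LOG_SNR_MAX at t=0 to LOG_SNR_MIN at t=1
        return LOG_SNR_MAX + t * (LOG_SNR_MIN - LOG_SNR_MAX)

    def d_log_snr_dt(t):
        # d/dt lambda(t) for the linear schedule above (a constant)
        return LOG_SNR_MIN - LOG_SNR_MAX

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def drift_and_squared_diffusion(z, t):
        # beta(t) = d/dt log(1 + exp(-lambda(t))) = -lambda'(t) * sigmoid(-lambda(t))
        lam = log_snr(t)
        beta = -d_log_snr_dt(t) * sigmoid(-lam)
        f = -0.5 * beta * z   # drift f(z, t)
        g2 = beta             # squared diffusion g(t)^2
        return f, g2

Given a score estimate, these two quantities plug directly into the SDE and ODE lines quoted in the diff above.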
@@ -95,9 +101,11 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
     def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
         """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

-        Default is a variance preserving schedule:
+        Default is a variance preserving schedule::
+
         alpha(t) = sqrt(sigmoid(log_snr_t))
         sigma(t) = sqrt(sigmoid(-log_snr_t))
+
         For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
         """
         if self._variance_type == "preserving":
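As a cross-check of the two variance types named in this docstring, a small standalone NumPy sketch (an assumed helper, not the library's method signature, which takes self and a training flag):

    # Minimal sketch (assumed helper; NumPy in place of backend tensors).
    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def get_alpha_sigma(log_snr_t, variance_type="preserving"):
        if variance_type == "preserving":
            alpha = np.sqrt(sigmoid(log_snr_t))   # alpha(t) = sqrt(sigmoid(lambda))
            sigma = np.sqrt(sigmoid(-log_snr_t))  # sigma(t) = sqrt(sigmoid(-lambda))
        else:  # variance exploding
            alpha = np.ones_like(log_snr_t)       # alpha^2 = 1
            sigma = np.sqrt(np.exp(-log_snr_t))   # sigma^2 = exp(-lambda)
        return alpha, sigma

    # Variance preserving means alpha^2 + sigma^2 = 1, since sigmoid(x) + sigmoid(-x) = 1.
    lam = np.linspace(-10.0, 10.0, 5)
    alpha, sigma = get_alpha_sigma(lam)
    assert np.allclose(alpha**2 + sigma**2, 1.0)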

0 commit comments