@@ -24,7 +24,7 @@ class NoiseSchedule(ABC):
2424 the same for the forward and reverse process, but this is not necessary and can be changed via the training flag.
2525
2626 [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data
27- Augmentation: Kingma et al. (2023)
27+ Augmentation: Kingma et al. (2023)
2828 """
2929
3030 def __init__ (
@@ -72,10 +72,16 @@ def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) ->
7272 def get_drift_diffusion (self , log_snr_t : Tensor , x : Tensor = None , training : bool = False ) -> tuple [Tensor , Tensor ]:
7373 r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
7474 It can be derived from the derivative of the schedule:
75- \beta(t) = d/dt log(1 + e^(-snr(t)))
75+
76+ .. math::
77+ \beta(t) = d/dt \log(1 + e^{-snr(t)})
78+
7679 f(z, t) = -0.5 * \beta(t) * z
80+
7781 g(t)^2 = \beta(t)
7882
83+ The corresponding differential equations are::
84+
7985 SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW
8086 ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt
8187
@@ -95,9 +101,11 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
95101 def get_alpha_sigma (self , log_snr_t : Tensor , training : bool ) -> tuple [Tensor , Tensor ]:
96102 """Get alpha and sigma for a given log signal-to-noise ratio (lambda).
97103
98- Default is a variance preserving schedule:
104+ Default is a variance preserving schedule::
105+
99106 alpha(t) = sqrt(sigmoid(log_snr_t))
100107 sigma(t) = sqrt(sigmoid(-log_snr_t))
108+
101109 For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda)
102110 """
103111 if self ._variance_type == "preserving" :
0 commit comments