1 | 1 | import math |
2 | | -from abc import ABC, abstractmethod |
3 | 2 | from typing import Union, Literal |
4 | 3 |
5 | 4 | from keras import ops |
6 | 5 |
7 | 6 | from bayesflow.types import Tensor |
8 | 7 | from bayesflow.utils.serialization import deserialize, serializable |
9 | 8 |
10 | | - |
11 | | -# disable module check, use potential module after moving from experimental |
12 | | -@serializable("bayesflow.networks", disable_module_check=True) |
13 | | -class NoiseSchedule(ABC): |
14 | | - r"""Noise schedule for diffusion models. We follow the notation from [1]. |
15 | | -
16 | | - The diffusion process is defined by a noise schedule, which determines how the noise level changes over time. |
17 | | - We define the noise schedule as a function of the log signal-to-noise ratio (lambda), which can be |
18 | | - interchangeably used with the diffusion time (t). |
19 | | -
20 | | - The noise process is defined as: z = alpha(t) * x + sigma(t) * e, where e ~ N(0, I). |
21 | | - The schedule is defined as: \lambda(t) = \log \alpha^2(t) - \log \sigma^2(t). |
22 | | -
23 | | - We can also define a weighting function for each noise level for the loss function. Often the noise schedule is |
24 | | - the same for the forward and reverse process, but this is not necessary and can be changed via the training flag. |
25 | | -
26 | | - [1] Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data |
27 | | - Augmentation: Kingma et al. (2023) |
28 | | - """ |
29 | | - |
30 | | - def __init__( |
31 | | - self, |
32 | | - name: str, |
33 | | - variance_type: Literal["preserving", "exploding"], |
34 | | - weighting: Literal["sigmoid", "likelihood_weighting"] = None, |
35 | | - ): |
36 | | - """ |
37 | | - Initialize the noise schedule. |
38 | | -
39 | | - Parameters |
40 | | - ---------- |
41 | | - name : str |
42 | | - The name of the noise schedule. |
43 | | - variance_type : Literal["preserving", "exploding"] |
44 | | - If the variance of noise added to the data should be preserved over time, use "preserving". |
45 | | - If the variance of noise added to the data should increase over time, use "exploding". |
46 | | - Default is "preserving". |
47 | | - weighting : Literal["sigmoid", "likelihood_weighting"], optional |
48 | | - The type of weighting function to use for the noise schedule. |
49 | | - Default is None, which means no weighting is applied. |
50 | | - """ |
51 | | - self.name = name |
52 | | - self._variance_type = variance_type |
53 | | - self.log_snr_min = None # should be set in the subclasses |
54 | | - self.log_snr_max = None # should be set in the subclasses |
55 | | - self._weighting = weighting |
56 | | - |
57 | | - @abstractmethod |
58 | | - def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor: |
59 | | - """Get the log signal-to-noise ratio (lambda) for a given diffusion time.""" |
60 | | - pass |
61 | | - |
62 | | - @abstractmethod |
63 | | - def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor: |
64 | | - """Get the diffusion time (t) from the log signal-to-noise ratio (lambda).""" |
65 | | - pass |
66 | | - |
67 | | - @abstractmethod |
68 | | - def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor: |
69 | | - r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE.""" |
70 | | - pass |
71 | | - |
72 | | - def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]: |
73 | | - r"""Compute the drift and optionally the squared diffusion term for the reverse SDE. |
74 | | - It can be derived from the derivative of the schedule: |
75 | | -
76 | | - .. math:: |
77 | | - \beta(t) = d/dt \log(1 + e^{-snr(t)}) |
78 | | -
79 | | - f(z, t) = -0.5 * \beta(t) * z |
80 | | -
81 | | - g(t)^2 = \beta(t) |
82 | | -
83 | | - The corresponding differential equations are:: |
84 | | -
85 | | - SDE: d(z) = [ f(z, t) - g(t)^2 * score(z, lambda) ] dt + g(t) dW |
86 | | - ODE: dz = [ f(z, t) - 0.5 * g(t)^2 * score(z, lambda) ] dt |
87 | | -
88 | | - For a variance exploding schedule, one should set f(z, t) = 0. |
89 | | - """ |
90 | | - beta = self.derivative_log_snr(log_snr_t=log_snr_t, training=training) |
91 | | - if x is None: # return g^2 only |
92 | | - return beta |
93 | | - if self._variance_type == "preserving": |
94 | | - f = -0.5 * beta * x |
95 | | - elif self._variance_type == "exploding": |
96 | | - f = ops.zeros_like(beta) |
97 | | - else: |
98 | | - raise ValueError(f"Unknown variance type: {self._variance_type}") |
99 | | - return f, beta |
100 | | - |
101 | | - def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]: |
102 | | - """Get alpha and sigma for a given log signal-to-noise ratio (lambda). |
103 | | -
104 | | - Default is a variance preserving schedule:: |
105 | | -
106 | | - alpha(t) = sqrt(sigmoid(log_snr_t)) |
107 | | - sigma(t) = sqrt(sigmoid(-log_snr_t)) |
108 | | -
109 | | - For a variance exploding schedule, one should set alpha^2 = 1 and sigma^2 = exp(-lambda) |
110 | | - """ |
111 | | - if self._variance_type == "preserving": |
112 | | - # variance preserving schedule |
113 | | - alpha_t = ops.sqrt(ops.sigmoid(log_snr_t)) |
114 | | - sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t)) |
115 | | - elif self._variance_type == "exploding": |
116 | | - # variance exploding schedule |
117 | | - alpha_t = ops.ones_like(log_snr_t) |
118 | | - sigma_t = ops.sqrt(ops.exp(-log_snr_t)) |
119 | | - else: |
120 | | - raise TypeError(f"Unknown variance type: {self._variance_type}") |
121 | | - return alpha_t, sigma_t |
122 | | - |
123 | | - def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor: |
124 | | - """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda). |
125 | | - Default weighting is None, which means only ones are returned. |
126 | | - Generally, weighting functions should be defined for a noise prediction loss. |
127 | | - """ |
128 | | - if self._weighting is None: |
129 | | - return ops.ones_like(log_snr_t) |
130 | | - elif self._weighting == "sigmoid": |
131 | | - # sigmoid weighting based on Kingma et al. (2023) |
132 | | - return ops.sigmoid(-log_snr_t + 2) |
133 | | - elif self._weighting == "likelihood_weighting": |
134 | | - # likelihood weighting based on Song et al. (2021) |
135 | | - g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t) |
136 | | - sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1] |
137 | | - return g_squared / ops.square(sigma_t) |
138 | | - else: |
139 | | - raise TypeError(f"Unknown weighting type: {self._weighting}") |
140 | | - |
141 | | - def get_config(self): |
142 | | - return dict(name=self.name, variance_type=self._variance_type) |
143 | | - |
144 | | - @classmethod |
145 | | - def from_config(cls, config, custom_objects=None): |
146 | | - return cls(**deserialize(config, custom_objects=custom_objects)) |
147 | | - |
148 | | - def validate(self): |
149 | | - """Validate the noise schedule.""" |
150 | | - if self.log_snr_min >= self.log_snr_max: |
151 | | - raise ValueError("min_log_snr must be less than max_log_snr.") |
152 | | - for training in [True, False]: |
153 | | - if not ops.isfinite(self.get_log_snr(0.0, training=training)): |
154 | | - raise ValueError(f"log_snr(0) must be finite with training={training}.") |
155 | | - if not ops.isfinite(self.get_log_snr(1.0, training=training)): |
156 | | - raise ValueError(f"log_snr(1) must be finite with training={training}.") |
157 | | - if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)): |
158 | | - raise ValueError(f"t(0) must be finite with training={training}.") |
159 | | - if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)): |
160 | | - raise ValueError(f"t(1) must be finite with training={training}.") |
161 | | - if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)): |
162 | | - raise ValueError("dt/t log_snr(0) must be finite.") |
163 | | - if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)): |
164 | | - raise ValueError("dt/t log_snr(1) must be finite.") |
| 9 | +from .noise_schedule_base import NoiseSchedule |
165 | 10 |
166 | 11 |
167 | 12 | @serializable("bayesflow.experimental") |
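For orientation, here is a minimal sketch (not part of this commit) of what a concrete subclass of the relocated NoiseSchedule base class could look like, using a simple linear log-SNR schedule. The class name LinearNoiseSchedule, the default bounds, and the relative import are illustrative assumptions; only the constructor signature and the abstract method names are taken from the class shown in the diff above.

```python
# Hypothetical example, not part of this commit: a linear log-SNR schedule
# built on the NoiseSchedule base class moved to noise_schedule_base.py.
from keras import ops

from .noise_schedule_base import NoiseSchedule  # import path added by this commit


class LinearNoiseSchedule(NoiseSchedule):
    """Interpolates lambda(t) linearly from log_snr_max (t=0) to log_snr_min (t=1)."""

    def __init__(self, log_snr_min: float = -10.0, log_snr_max: float = 10.0):
        super().__init__(name="linear_noise_schedule", variance_type="preserving")
        self.log_snr_min = log_snr_min
        self.log_snr_max = log_snr_max

    def get_log_snr(self, t, training: bool):
        # lambda(t) = lambda_max - t * (lambda_max - lambda_min)
        return self.log_snr_max - t * (self.log_snr_max - self.log_snr_min)

    def get_t_from_log_snr(self, log_snr_t, training: bool):
        # inverse of the linear map above
        return (self.log_snr_max - log_snr_t) / (self.log_snr_max - self.log_snr_min)

    def derivative_log_snr(self, log_snr_t, training: bool):
        # beta(t) = d/dt log(1 + exp(-lambda(t))) = -(dlambda/dt) * sigmoid(-lambda(t)),
        # and dlambda/dt = -(lambda_max - lambda_min) is constant for this schedule
        return (self.log_snr_max - self.log_snr_min) * ops.sigmoid(-log_snr_t)
```

A subclass like this inherits get_alpha_sigma, get_drift_diffusion, and validate() from the base class; LinearNoiseSchedule().validate() should pass, since the log-SNR, its inverse, and beta(t) are all finite at both endpoints.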