
Commit 549a055

Add diffusion model implementation, EDM variant
Preliminary implementation, to be extended with other variants as well.
1 parent 2bf0b53 commit 549a055

3 files changed (+368, -2 lines)

bayesflow/experimental/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@
 from .cif import CIF
 from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
+from .diffusion_model import DiffusionModel
 from .free_form_flow import FreeFormFlow

 from ..utils._docs import _add_imports_to_all
bayesflow/experimental/diffusion_model.py
Lines changed: 346 additions & 0 deletions

@@ -0,0 +1,346 @@
from collections.abc import Sequence
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor, Shape
import bayesflow as bf
from bayesflow.networks import InferenceNetwork

from bayesflow.utils import (
    expand_right_as,
    find_network,
    jacobian_trace,
    keras_kwargs,
    serialize_value_or_type,
    deserialize_value_or_type,
    weighted_mean,
    integrate,
)


@serializable(package="bayesflow.networks")
class DiffusionModel(InferenceNetwork):
    """Diffusion model, implemented as the Elucidated Diffusion Model (EDM) described in [1].

    [1] Karras et al. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. arXiv:2206.00364
    """

    MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.0,
        "spectral_normalization": False,
    }

    INTEGRATE_DEFAULT_CONFIG = {
        "method": "euler",
        "steps": 100,
    }

    def __init__(
        self,
        subnet: str | type = "mlp",
        integrate_kwargs: dict[str, any] = None,
        subnet_kwargs: dict[str, any] = None,
        sigma_data=1.0,
        **kwargs,
    ):
        """
        Initializes a diffusion model with a configurable subnet architecture.

        This model learns a transformation from a Gaussian latent distribution to a target distribution
        using a specified subnet type, which can be an MLP or a custom network.

        The integration procedure can be customized with additional parameters via the respective
        configuration dictionary (`integrate_kwargs`).

        Parameters
        ----------
        subnet : str or type, optional
            The architecture used for the transformation network. Can be "mlp" or a custom
            callable network. Default is "mlp".
        integrate_kwargs : dict[str, any], optional
            Additional keyword arguments for the integration process. Default is None.
        subnet_kwargs : dict[str, any], optional
            Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
        sigma_data : float, optional
            Average standard deviation of the target distribution. Default is 1.0.
        **kwargs
            Additional keyword arguments passed to the subnet and other components.
        """

        super().__init__(base_distribution=None, **keras_kwargs(kwargs))

        # internal tunable parameters not intended to be modified by the average user
        self.max_sigma = kwargs.get("max_sigma", 80.0)
        self.min_sigma = kwargs.get("min_sigma", 1e-4)
        self.rho = kwargs.get("rho", 7)
        # hyper-parameters for sampling the noise level
        self.p_mean = kwargs.get("p_mean", -1.2)
        self.p_std = kwargs.get("p_std", 1.2)

        # latent distribution (not configurable)
        self.base_distribution = bf.distributions.DiagonalNormal(mean=0.0, std=self.max_sigma)
        self.integrate_kwargs = self.INTEGRATE_DEFAULT_CONFIG | (integrate_kwargs or {})

        self.sigma_data = sigma_data

        self.seed_generator = keras.random.SeedGenerator()

        subnet_kwargs = subnet_kwargs or {}
        if subnet == "mlp":
            subnet_kwargs = self.MLP_DEFAULT_CONFIG | subnet_kwargs

        self.subnet = find_network(subnet, **subnet_kwargs)
        self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros")

        # serialization: store all parameters necessary to call __init__
        self.config = {
            "integrate_kwargs": self.integrate_kwargs,
            "subnet_kwargs": subnet_kwargs,
            "sigma_data": sigma_data,
            **kwargs,
        }
        self.config = serialize_value_or_type(self.config, "subnet", subnet)

    def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
        super().build(xz_shape, conditions_shape=conditions_shape)

        self.output_projector.units = xz_shape[-1]
        input_shape = list(xz_shape)

        # reserve one input dimension for the noise level (time)
        input_shape[-1] += 1
        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.subnet.build(input_shape)
        out_shape = self.subnet.compute_output_shape(input_shape)
        self.output_projector.build(out_shape)

    def get_config(self):
        base_config = super().get_config()
        return base_config | self.config

    @classmethod
    def from_config(cls, config):
        config = deserialize_value_or_type(config, "subnet")
        return cls(**config)

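    # EDM preconditioning (Karras et al., arXiv:2206.00364): the denoiser is parameterized as
    #   D(x; sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma)),
    # where F is the trainable subnet. The four helpers below compute these coefficients from
    # sigma and sigma_data; _denoiser_fn combines them.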
    def _c_skip_fn(self, sigma):
        return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)

    def _c_out_fn(self, sigma):
        return sigma * self.sigma_data / ops.sqrt(self.sigma_data**2 + sigma**2)

    def _c_in_fn(self, sigma):
        return 1.0 / ops.sqrt(sigma**2 + self.sigma_data**2)

    def _c_noise_fn(self, sigma):
        return 0.25 * ops.log(sigma)

    def _denoiser_fn(
        self,
        xz: Tensor,
        sigma: Tensor,
        conditions: Tensor = None,
        training: bool = False,
    ):
        # calculate output of the network
        c_in = self._c_in_fn(sigma)
        c_noise = self._c_noise_fn(sigma)
        xz_pre = c_in * xz
        if conditions is None:
            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
        else:
            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)
        out = self.output_projector(self.subnet(xtc, training=training), training=training)
        return self._c_skip_fn(sigma) * xz + self._c_out_fn(sigma) * out

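    # Probability-flow ODE of the EDM formulation: dx/dsigma = (x - D(x; sigma)) / sigma.
    # Integrating from max_sigma down to min_sigma transports latent draws to the target
    # distribution; the reverse direction maps data to the Gaussian latent space.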
    def velocity(
        self,
        xz: Tensor,
        sigma: float | Tensor,
        conditions: Tensor = None,
        training: bool = False,
    ) -> Tensor:
        # transform sigma vector into correct shape
        sigma = keras.ops.convert_to_tensor(sigma, dtype=keras.ops.dtype(xz))
        sigma = expand_right_as(sigma, xz)
        sigma = keras.ops.broadcast_to(sigma, keras.ops.shape(xz)[:-1] + (1,))

        d = self._denoiser_fn(xz, sigma, conditions, training=training)
        return (xz - d) / sigma

    def _velocity_trace(
        self,
        xz: Tensor,
        sigma: Tensor,
        conditions: Tensor = None,
        max_steps: int = None,
        training: bool = False,
    ) -> (Tensor, Tensor):
        def f(x):
            return self.velocity(x, sigma=sigma, conditions=conditions, training=training)

        v, trace = jacobian_trace(f, xz, max_steps=max_steps, seed=self.seed_generator, return_output=True)

        return v, keras.ops.expand_dims(trace, axis=-1)

    def _forward(
        self,
        x: Tensor,
        conditions: Tensor = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        integrate_kwargs = self.integrate_kwargs | kwargs
        if isinstance(integrate_kwargs["steps"], int):
            # set schedule for specified number of steps
            integrate_kwargs["steps"] = self._integration_schedule(integrate_kwargs["steps"], dtype=ops.dtype(x))
        if density:

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
                "xz": x,
                "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x)),
            }
            state = integrate(
                deltas,
                state,
                **integrate_kwargs,
            )

            z = state["xz"]
            log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

            return z, log_density

        def deltas(time, xz):
            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}

        state = {"xz": x}
        state = integrate(
            deltas,
            state,
            **integrate_kwargs,
        )

        z = state["xz"]

        return z

    def _inverse(
        self,
        z: Tensor,
        conditions: Tensor = None,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        integrate_kwargs = self.integrate_kwargs | kwargs
        if isinstance(integrate_kwargs["steps"], int):
            # set schedule for specified number of steps
            integrate_kwargs["steps"] = self._integration_schedule(
                integrate_kwargs["steps"], inverse=True, dtype=ops.dtype(z)
            )
        if density:

            def deltas(time, xz):
                v, trace = self._velocity_trace(xz, sigma=time, conditions=conditions, training=training)
                return {"xz": v, "trace": trace}

            state = {
                "xz": z,
                "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z)),
            }
            state = integrate(deltas, state, **integrate_kwargs)

            x = state["xz"]
            log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

            return x, log_density

        def deltas(time, xz):
            return {"xz": self.velocity(xz, sigma=time, conditions=conditions, training=training)}

        state = {"xz": z}
        state = integrate(
            deltas,
            state,
            **integrate_kwargs,
        )

        x = state["xz"]

        return x

    def compute_metrics(
        self,
        x: Tensor | Sequence[Tensor, ...],
        conditions: Tensor = None,
        sample_weight: Tensor = None,
        stage: str = "training",
    ) -> dict[str, Tensor]:
        training = stage == "training"
        if not self.built:
            xz_shape = keras.ops.shape(x)
            conditions_shape = None if conditions is None else keras.ops.shape(conditions)
            self.build(xz_shape, conditions_shape)

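        # EDM training step: sample the noise level log-normally (exp of N(p_mean, p_std^2)), perturb x
        # with Gaussian noise of scale sigma, and regress the preconditioned network output onto the
        # target (x - c_skip * (x + z)) / c_out; the effective weight lambda(sigma) * c_out(sigma)^2 equals 1.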
        # sample log-noise level
        log_sigma = self.p_mean + self.p_std * keras.random.normal(
            ops.shape(x)[:1], dtype=ops.dtype(x), seed=self.seed_generator
        )
        # noise level with shape (batch_size, 1)
        sigma = ops.exp(log_sigma)[:, None]

        # generate noise vector
        z = sigma * keras.random.normal(ops.shape(x), dtype=ops.dtype(x), seed=self.seed_generator)

        # calculate preconditioning
        c_skip = self._c_skip_fn(sigma)
        c_out = self._c_out_fn(sigma)
        c_in = self._c_in_fn(sigma)
        c_noise = self._c_noise_fn(sigma)
        xz_pre = c_in * (x + z)

        # calculate output of the network
        if conditions is None:
            xtc = keras.ops.concatenate([xz_pre, c_noise], axis=-1)
        else:
            xtc = keras.ops.concatenate([xz_pre, c_noise, conditions], axis=-1)

        out = self.output_projector(self.subnet(xtc, training=training), training=training)

        # calculate loss
        lam = 1 / c_out[:, 0] ** 2
        effective_weight = lam * c_out[:, 0] ** 2
        unweighted_loss = ops.mean((out - 1 / c_out * (x - c_skip * (x + z))) ** 2, axis=-1)
        loss = effective_weight * unweighted_loss
        loss = weighted_mean(loss, sample_weight)

        base_metrics = super().compute_metrics(x, conditions, sample_weight, stage)
        return base_metrics | {"loss": loss}

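    # Noise-level schedule from the EDM paper:
    #   sigma_i = (max_sigma^(1/rho) + i/(N-1) * (min_sigma^(1/rho) - max_sigma^(1/rho)))^rho,  i = 0..N-1,
    # which decreases from max_sigma to min_sigma and is flipped for the forward (data -> latent) pass.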
    def _integration_schedule(self, steps, inverse=False, dtype=None):
        def sigma_i(i, steps):
            N = steps + 1
            return (
                self.max_sigma ** (1 / self.rho)
                + (i / (N - 1)) * (self.min_sigma ** (1 / self.rho) - self.max_sigma ** (1 / self.rho))
            ) ** self.rho

        steps = sigma_i(ops.arange(steps + 1, dtype=dtype), steps)
        if not inverse:
            steps = ops.flip(steps)
        return steps
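
As a rough illustration of how the pieces fit together, the following sketch builds the new class on dummy tensors, evaluates the EDM training loss for one batch, and draws samples by integrating the ODE from latent space down to data space. The shapes, the random dummy data, and the direct calls to build, compute_metrics, and the internal _inverse method are illustrative assumptions rather than part of this commit; the import path follows the change to bayesflow/experimental/__init__.py above.

import numpy as np
import keras  # requires a configured Keras 3 backend (jax, tensorflow, or torch)

from bayesflow.experimental import DiffusionModel  # export added by this commit

model = DiffusionModel(subnet="mlp", integrate_kwargs={"method": "euler", "steps": 50})

# dummy batch (assumed shapes): 8 two-dimensional targets with 3 conditioning variables
x = keras.ops.convert_to_tensor(np.random.randn(8, 2).astype("float32"))
conditions = keras.ops.convert_to_tensor(np.random.randn(8, 3).astype("float32"))

model.build(keras.ops.shape(x), keras.ops.shape(conditions))

# EDM denoising loss for this batch
metrics = model.compute_metrics(x, conditions=conditions, stage="training")
print(metrics["loss"])

# sampling sketch: draw from the Gaussian base (std = max_sigma) and integrate down to min_sigma
z = keras.ops.convert_to_tensor(np.random.randn(8, 2).astype("float32")) * model.max_sigma
samples = model._inverse(z, conditions=conditions)
print(keras.ops.shape(samples))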
