
Commit 41f4b08

Adapt sCMs [skip ci]
1 parent a0d8d5b commit 41f4b08

File tree

5 files changed: +2069, -732 lines changed

bayesflow/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 """
 
 from .cif import CIF
-from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
+from .stable_consistency_model import StableConsistencyModel
 from .diffusion_model import DiffusionModel
 from .free_form_flow import FreeFormFlow
 

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .stable_consistency_model import StableConsistencyModel

bayesflow/experimental/continuous_time_consistency_model.py renamed to bayesflow/experimental/stable_consistency_model/stable_consistency_model.py

Lines changed: 84 additions & 58 deletions
@@ -1,71 +1,107 @@
+from math import pi
+
 import keras
 from keras import ops
 
-import numpy as np
-
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
 from bayesflow.utils import (
+    logging,
     jvp,
     concatenate_valid,
     find_network,
     expand_right_as,
     expand_right_to,
-    model_kwargs,
+    layer_kwargs,
 )
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
-
 from bayesflow.networks import InferenceNetwork
 from bayesflow.networks.embeddings import FourierEmbedding
 
 
 # disable module check, use potential module after moving from experimental
 @serializable("bayesflow.networks", disable_module_check=True)
-class ContinuousTimeConsistencyModel(InferenceNetwork):
-    """(IN) Implements an sCM (simple, stable, and scalable Consistency Model)
-    with continous-time Consistency Training (CT) as described in [1].
-    The sampling procedure is taken from [2].
+class StableConsistencyModel(InferenceNetwork):
+    """(IN) Implements an sCM (simple, stable, and scalable Consistency Model) with continuous-time Consistency Training
+    (CT) as described in [1]. The sampling procedure is taken from [2].
 
     [1] Lu, C., & Song, Y. (2024).
     Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models
     arXiv preprint arXiv:2410.11081
 
     [2] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023).
-    Consistency Models.
-    arXiv preprint arXiv:2303.01469
+    Consistency Models. arXiv preprint arXiv:2303.01469
     """
 
+    MLP_DEFAULT_CONFIG = {
+        "widths": (256, 256, 256, 256, 256),
+        "activation": "mish",
+        "kernel_initializer": "he_normal",
+        "residual": True,
+        "dropout": 0.05,
+        "spectral_normalization": False,
+    }
+
+    WEIGHT_MLP_DEFAULT_CONFIG = {
+        "widths": (256,),
+        "activation": "mish",
+        "kernel_initializer": "he_normal",
+        "residual": False,
+        "dropout": 0.05,
+        "spectral_normalization": False,
+    }
+
+    EPS_WARN = 0.1
+
     def __init__(
         self,
-        subnet: str | keras.Layer = "mlp",
-        sigma_data: float = 1.0,
+        subnet: str | type | keras.Layer = "mlp",
+        sigma: float = 1.0,
         subnet_kwargs: dict[str, any] = None,
+        weight_mlp_kwargs: dict[str, any] = None,
         embedding_kwargs: dict[str, any] = None,
         **kwargs,
     ):
         """Creates an instance of an sCM to be used for consistency training (CT).
 
         Parameters
         ----------
-        subnet : str or type, optional, default: "mlp"
-            A neural network type for the consistency model, will be
-            instantiated using subnet_kwargs.
-        sigma_data : float, optional, default: 1.0
-            Standard deviation of the target distribution
+        subnet : str, type, or keras.Layer, optional, default="mlp"
+            The neural network architecture used for the consistency model.
+            If a string is provided, it should be a registered name (e.g., "mlp").
+            If a type or keras.Layer is provided, it will be directly instantiated
+            with the given ``subnet_kwargs``.
+        sigma : float, optional, default=1.0
+            Standard deviation of the target distribution for the consistency loss.
+            Controls the scale of the noise injected during training.
+        subnet_kwargs : dict[str, any], optional, default=None
+            Keyword arguments passed to the constructor of the chosen ``subnet``. For example, number of hidden units,
+            activation functions, or dropout settings.
+        weight_mlp_kwargs : dict[str, any], optional, default=None
+            Keyword arguments for an auxiliary MLP used to generate weights within the consistency model. Typically
+            includes depth, hidden sizes, and non-linearity choices.
+        embedding_kwargs : dict[str, any], optional, default=None
+            Keyword arguments for the time embedding layer(s) used in the model
         **kwargs
-            Additional keyword arguments to the layer.
+            Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
+            (e.g., ``name``, ``dtype``, or ``trainable``).
         """
         super().__init__(base_distribution="normal", **kwargs)
 
         subnet_kwargs = subnet_kwargs or {}
-
+        if subnet == "mlp":
+            subnet_kwargs = StableConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
+
         self.subnet_projector = keras.layers.Dense(
             units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
         )
 
-        self.weight_fn = MLP([256], dropout=0.0)
+        weight_mlp_kwargs = weight_mlp_kwargs or {}
+        weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
+        self.weight_fn = MLP(**weight_mlp_kwargs)
+
         self.weight_fn_projector = keras.layers.Dense(
             units=1, bias_initializer="zeros", kernel_initializer="zeros", name="weight_fn_projector"
         )
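
As a hedged sketch of the constructor surface introduced above (parameter names and the default-merging behaviour are taken from the hunk; the concrete values, and the assumption that bayesflow.experimental re-exports the class as in the __init__.py change, are illustrative only):

    # Illustrative only; assumes a bayesflow build containing this commit.
    from bayesflow.experimental import StableConsistencyModel

    # When subnet == "mlp", subnet_kwargs are merged over MLP_DEFAULT_CONFIG,
    # and weight_mlp_kwargs over WEIGHT_MLP_DEFAULT_CONFIG, so only overrides
    # need to be passed.
    scm = StableConsistencyModel(
        subnet="mlp",
        sigma=1.0,                                  # std of the target distribution
        subnet_kwargs={"widths": (128, 128, 128)},  # overrides the default widths
        weight_mlp_kwargs={"dropout": 0.0},         # overrides the default dropout
    )
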
@@ -74,8 +110,7 @@ def __init__(
         self.time_emb = FourierEmbedding(**embedding_kwargs)
         self.time_emb_dim = self.time_emb.embed_dim
 
-        self.sigma_data = sigma_data
-
+        self.sigma = sigma
         self.seed_generator = keras.random.SeedGenerator()
 
     @classmethod
@@ -84,29 +119,33 @@ def from_config(cls, config, custom_objects=None):
 
     def get_config(self):
         base_config = super().get_config()
-        base_config = model_kwargs(base_config)
+        base_config = layer_kwargs(base_config)
 
         config = {
             "subnet": self.subnet,
-            "sigma_data": self.sigma_data,
+            "sigma": self.sigma,
         }
 
         return base_config | serialize(config)
 
     def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
-        t = np.linspace(0.0, np.pi / 2, num_steps)
-        times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
-        times[0] = 0.0
+        t = keras.ops.linspace(0.0, pi / 2, num_steps)
+        times = keras.ops.exp((t - pi / 2) * rho) * pi / 2
+        times.at[0].set(0.0)
 
         # if rho is set too low, bad schedules can occur
-        EPS_WARN = 0.1
-        if times[1] > EPS_WARN:
-            print("Warning: The last time step is large.")
-            print(f"Increasing rho (was {rho}) or n_steps (was {num_steps}) might improve results.")
-        return ops.convert_to_tensor(times)
+        if times[1] > StableConsistencyModel.EPS_WARN:
+            logging.warning("Warning: The last time step is large.")
+            logging.warning(f"Increasing rho (was {rho}) or n_steps (was {num_steps}) might improve results.")
+        return times
 
     def build(self, xz_shape, conditions_shape=None):
-        super().build(xz_shape)
+        if self.built:
+            # building when the network is already built can cause issues with serialization
+            # see https://github.com/keras-team/keras/issues/21147
+            return
+
+        self.base_distribution.build(xz_shape)
         self.subnet_projector.units = xz_shape[-1]
 
         # construct input shape for subnet and subnet projector
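
The schedule computed by _discretize_time is times_i = (pi / 2) * exp((t_i - pi / 2) * rho) for t_i evenly spaced on [0, pi / 2], with the first entry forced to zero. A small standalone NumPy sketch of the same schedule (illustrative only, not the backend-agnostic keras.ops version used in the commit):

    import numpy as np

    def discretize_time(num_steps: int, rho: float = 3.5) -> np.ndarray:
        # evenly spaced angles on [0, pi/2], exponentially warped towards zero
        t = np.linspace(0.0, np.pi / 2, num_steps)
        times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
        times[0] = 0.0  # the schedule always starts exactly at t = 0
        return times

    # a larger rho concentrates steps near t = 0; times[1] > EPS_WARN (0.1) triggers the warning above
    print(discretize_time(15, rho=3.5))
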
@@ -134,17 +173,6 @@ def build(self, xz_shape, conditions_shape=None):
         input_shape = self.weight_fn.compute_output_shape(input_shape)
         self.weight_fn_projector.build(input_shape)
 
-    def call(
-        self,
-        xz: Tensor,
-        conditions: Tensor = None,
-        inverse: bool = False,
-        **kwargs,
-    ):
-        if inverse:
-            return self._inverse(xz, conditions=conditions, **kwargs)
-        return self._forward(xz, conditions=conditions, **kwargs)
-
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
@@ -172,8 +200,8 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         steps = kwargs.get("steps", 15)
         rho = kwargs.get("rho", 3.5)
 
-        # noise distribution has variance sigma_data
-        x = keras.ops.copy(z) * self.sigma_data
+        # noise distribution has variance sigma
+        x = keras.ops.copy(z) * self.sigma
         discretized_time = keras.ops.flip(self._discretize_time(steps, rho=rho), axis=-1)
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
         x = self.consistency_function(x, t, conditions=conditions)
@@ -207,9 +235,9 @@ def consistency_function(
         **kwargs : dict, optional, default: {}
             Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate_valid([x / self.sigma_data, self.time_emb(t), conditions], axis=-1)
+        xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
         f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
-        out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
+        out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
         return out
 
     def compute_metrics(
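
consistency_function therefore parameterizes the model as f(x, t) = cos(t) * x - sin(t) * sigma * F(x / sigma, t, conditions), with F being the subnet followed by the projector. A minimal numerical sketch with a hypothetical stand-in for F (not the real subnet):

    import numpy as np

    sigma = 1.0

    def F(x_scaled, t):
        # hypothetical stand-in for subnet + subnet_projector
        return np.tanh(x_scaled + t)

    def consistency_function(x, t):
        # same form as in the hunk above: cos(t) * x - sin(t) * sigma * F(x / sigma, t)
        return np.cos(t) * x - np.sin(t) * sigma * F(x / sigma, t)

    x = np.array([0.5, -1.2])
    print(consistency_function(x, t=0.0))  # equals x at t = 0 (boundary condition of consistency models)
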
@@ -226,17 +254,14 @@ def compute_metrics(
         c = 0.1
 
         # generate noise vector
-        z = (
-            keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
-            * self.sigma_data
-        )
+        z = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) * self.sigma
 
         # sample time
         tau = (
             keras.random.normal(keras.ops.shape(x)[:1], dtype=keras.ops.dtype(x), seed=self.seed_generator) * p_std
             + p_mean
         )
-        t_ = ops.arctan(ops.exp(tau) / self.sigma_data)
+        t_ = ops.arctan(ops.exp(tau) / self.sigma)
         t = expand_right_as(t_, x)
 
         # generate noisy sample
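
The time sampling above draws tau from a normal distribution with mean p_mean and standard deviation p_std (both defined outside the shown hunks) and maps it through t = arctan(exp(tau) / sigma), which always lands in (0, pi / 2). A quick sketch with placeholder values (not taken from the commit):

    import numpy as np

    sigma, p_mean, p_std = 1.0, -1.0, 1.4  # placeholder values, not from the commit

    tau = np.random.normal(p_mean, p_std, size=10_000)
    t = np.arctan(np.exp(tau) / sigma)  # log-normal noise scales mapped into (0, pi/2)
    print(t.min(), t.max())
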
@@ -251,23 +276,23 @@ def f_teacher(x, t):
             o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
             return self.subnet_projector(o)
 
-        primals = (xt / self.sigma_data, t)
+        primals = (xt / self.sigma, t)
         tangents = (
             ops.cos(t) * ops.sin(t) * dxtdt,
-            ops.cos(t) * ops.sin(t) * self.sigma_data,
+            ops.cos(t) * ops.sin(t) * self.sigma,
         )
 
         teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
         teacher_output = ops.stop_gradient(teacher_output)
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        xtc = concatenate_valid([xt / self.sigma_data, self.time_emb(t), conditions], axis=-1)
+        xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
         student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
 
         # calculate the tangent
-        g = -(ops.cos(t) ** 2) * (self.sigma_data * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
-            xt + self.sigma_data * cos_sin_dFdt
+        g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
+            xt + self.sigma * cos_sin_dFdt
         )
 
         # apply normalization to stabilize training
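
Written out, the tangent target assembled in this hunk is

    g = -cos(t)^2 * (sigma * F_teacher - dx_t/dt) - r * cos(t) * sin(t) * (x_t + sigma * cos_sin_dFdt),

where cos_sin_dFdt is the stop-gradient JVP of the teacher network along the tangents listed above (an estimate of cos(t) * sin(t) * dF/dt), and r is a factor computed outside the shown hunks.
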
@@ -277,6 +302,7 @@ def f_teacher(x, t):
         w = self.weight_fn_projector(self.weight_fn(expand_right_to(t_, 2)))
 
         D = ops.shape(x)[-1]
+
         loss = ops.mean(
             (ops.exp(w) / D)
             * ops.mean(
