Commit 9abb126

fix serialization in StableConsistencyModel (#578)
1 parent 08ed995 commit 9abb126

2 files changed: 81 additions & 27 deletions

bayesflow/experimental/stable_consistency_model/stable_consistency_model.py

Lines changed: 60 additions & 26 deletions

```diff
@@ -5,15 +5,7 @@
 
 from bayesflow.networks import MLP
 from bayesflow.types import Tensor
-from bayesflow.utils import (
-    logging,
-    jvp,
-    concatenate_valid,
-    find_network,
-    expand_right_as,
-    expand_right_to,
-    layer_kwargs,
-)
+from bayesflow.utils import logging, jvp, find_network, expand_right_as, expand_right_to, layer_kwargs, tensor_utils
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from bayesflow.networks import InferenceNetwork
@@ -83,6 +75,11 @@ def __init__(
             includes depth, hidden sizes, and non-linearity choices.
         embedding_kwargs : dict[str, any], optional, default=None
             Keyword arguments for the time embedding layer(s) used in the model
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
+            and optional 'conditions'. Default is True.
         **kwargs
             Additional keyword arguments passed to the parent ``InferenceNetwork`` initializer
             (e.g., ``name``, ``dtype``, or ``trainable``).
@@ -97,6 +94,7 @@ def __init__(
         self.subnet_projector = keras.layers.Dense(
             units=None, bias_initializer="zeros", kernel_initializer="zeros", name="subnet_projector"
         )
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         weight_mlp_kwargs = weight_mlp_kwargs or {}
         weight_mlp_kwargs = StableConsistencyModel.WEIGHT_MLP_DEFAULT_CONFIG | weight_mlp_kwargs
@@ -107,6 +105,7 @@ def __init__(
         )
 
         embedding_kwargs = embedding_kwargs or {}
+        self.embedding_kwargs = embedding_kwargs
         self.time_emb = FourierEmbedding(**embedding_kwargs)
         self.time_emb_dim = self.time_emb.embed_dim
 
@@ -124,6 +123,8 @@ def get_config(self):
         config = {
             "subnet": self.subnet,
             "sigma": self.sigma,
+            "embedding_kwargs": self.embedding_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
         }
 
         return base_config | serialize(config)
@@ -151,17 +152,22 @@ def build(self, xz_shape, conditions_shape=None):
         # construct input shape for subnet and subnet projector
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += self.time_emb_dim + 1
-
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += self.time_emb_dim + 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
+
+            self.subnet.build(input_shape)
+            input_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (self.time_emb_dim + 1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            input_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
         self.subnet_projector.build(input_shape)
 
         # input shape for time embedding
@@ -173,6 +179,35 @@ def build(self, xz_shape, conditions_shape=None):
         input_shape = self.weight_fn.compute_output_shape(input_shape)
         self.weight_fn_projector.build(input_shape)
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
+        the time `t`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         # Consistency Models only learn the direction from noise distribution
         # to target distribution, so we cannot implement this function.
@@ -218,7 +253,6 @@ def consistency_function(
         t: Tensor,
         conditions: Tensor = None,
         training: bool = False,
-        **kwargs,
     ) -> Tensor:
         """Compute consistency function at time t.
 
@@ -235,8 +269,8 @@
         **kwargs : dict, optional, default: {}
             Additional keyword arguments passed to the inner network.
         """
-        xtc = concatenate_valid([x / self.sigma, self.time_emb(t), conditions], axis=-1)
-        f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))
+        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=training)
+        f = self.subnet_projector(subnet_out)
         out = ops.cos(t) * x - ops.sin(t) * self.sigma * f
         return out
 
@@ -273,7 +307,7 @@ def compute_metrics(
         r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here
 
         def f_teacher(x, t):
-            o = self.subnet(concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training")
+            o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
             return self.subnet_projector(o)
 
         primals = (xt / self.sigma, t)
@@ -287,8 +321,8 @@ def f_teacher(x, t):
         cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)
 
         # calculate output of the network
-        xtc = concatenate_valid([xt / self.sigma, self.time_emb(t), conditions], axis=-1)
-        student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))
+        subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training")
+        student_out = self.subnet_projector(subnet_out)
 
         # calculate the tangent
         g = -(ops.cos(t) ** 2) * (self.sigma * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
```
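The `concatenate_subnet_input=False` branch above implies a contract for the subnet: keyword-based `build(x_shape=..., t_shape=..., conditions_shape=...)` and `compute_output_shape(...)`, plus a `call` that accepts `x`, `t`, and `conditions` as separate arguments. Below is a minimal sketch of a conforming layer; the class name `SeparateInputSubnet`, the layer widths, and the sum-based fusion are illustrative assumptions, not part of this commit:

```python
import keras


class SeparateInputSubnet(keras.layers.Layer):
    """Hypothetical subnet satisfying the separate-input contract used
    when StableConsistencyModel is created with concatenate_subnet_input=False."""

    def __init__(self, width: int = 256, **kwargs):
        super().__init__(**kwargs)
        self.dense_x = keras.layers.Dense(width)  # noisy parameters
        self.dense_t = keras.layers.Dense(width)  # time embedding stream
        self.dense_c = keras.layers.Dense(width)  # optional conditions
        self.out = keras.layers.Dense(width, activation="silu")

    def build(self, x_shape, t_shape, conditions_shape=None):
        # mirrors the keyword call in StableConsistencyModel.build above
        self.dense_x.build(x_shape)
        self.dense_t.build(t_shape)
        if conditions_shape is not None:
            self.dense_c.build(conditions_shape)
        self.out.build(tuple(x_shape[:-1]) + (self.dense_x.units,))

    def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
        return tuple(x_shape[:-1]) + (self.out.units,)

    def call(self, x, t, conditions=None, training=False):
        # fuse the three streams by summation instead of concatenation
        h = self.dense_x(x) + self.dense_t(t)
        if conditions is not None:
            h = h + self.dense_c(conditions)
        return self.out(h)
```

Such a layer would then be passed as `StableConsistencyModel(subnet=SeparateInputSubnet(), concatenate_subnet_input=False)`, assuming `find_network` accepts a layer instance here as it does for other BayesFlow networks.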

bayesflow/networks/embeddings/fourier_embedding.py

Lines changed: 21 additions & 1 deletion

```diff
@@ -4,7 +4,8 @@
 from keras import ops
 
 from bayesflow.types import Tensor
-from bayesflow.utils.serialization import serializable
+from bayesflow.utils import layer_kwargs
+from bayesflow.utils.serialization import serializable, serialize, deserialize
 
 
 @serializable("bayesflow.networks")
@@ -47,6 +48,8 @@ def __init__(
         self.scale = scale
         self.embed_dim = embed_dim
         self.include_identity = include_identity
+        self.initializer = initializer
+        self.trainable = trainable
 
     def call(self, t: Tensor) -> Tensor:
         """Embeds the one-dimensional time scalar into a higher-dimensional Fourier embedding.
@@ -68,3 +71,20 @@
         else:
             emb = ops.concatenate([ops.sin(proj), ops.cos(proj)], axis=-1)
         return emb
+
+    def get_config(self):
+        base_config = super().get_config()
+        base_config = layer_kwargs(base_config)
+
+        config = {
+            "embed_dim": self.embed_dim,
+            "scale": self.scale,
+            "initializer": self.initializer,
+            "trainable": self.trainable,
+            "include_identity": self.include_identity,
+        }
+        return base_config | serialize(config)
+
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))
```
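The added `get_config`/`from_config` pair is the substance of the fix in this file: without it, the base Keras `get_config` records nothing about the constructor arguments, so a reloaded `FourierEmbedding` falls back to default settings. A minimal round-trip sketch; the import path and argument values are assumptions for illustration:

```python
import keras

# import path inferred from the file location above
from bayesflow.networks.embeddings.fourier_embedding import FourierEmbedding

emb = FourierEmbedding(embed_dim=32, include_identity=True)

# get_config now records every constructor argument, so from_config
# rebuilds an equivalent layer instead of one with default settings
clone = FourierEmbedding.from_config(emb.get_config())
assert clone.embed_dim == emb.embed_dim
assert clone.include_identity == emb.include_identity

# the generic Keras round-trip should likewise work, assuming the
# @serializable decorator registers the class with Keras' custom objects
blob = keras.saving.serialize_keras_object(emb)
restored = keras.saving.deserialize_keras_object(blob)
```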

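Taken together, the two changes mean a `StableConsistencyModel` should now survive a Keras serialization cycle with its `embedding_kwargs` and `concatenate_subnet_input` settings intact. A rough end-to-end sketch; the import path, constructor arguments, and shapes are illustrative assumptions:

```python
import keras

# import path and arguments assumed for illustration
from bayesflow.experimental.stable_consistency_model import StableConsistencyModel

net = StableConsistencyModel(subnet="mlp", embedding_kwargs={"embed_dim": 32})
net.build(xz_shape=(16, 4), conditions_shape=(16, 6))

# before this commit, embedding_kwargs and concatenate_subnet_input were not
# part of get_config, so a restored model silently used default settings
blob = keras.saving.serialize_keras_object(net)
restored = keras.saving.deserialize_keras_object(blob)
assert restored.time_emb.embed_dim == net.time_emb.embed_dim
```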