Commit ceab303

Add standardization to continuous approximator and test
1 parent 4781e2e commit ceab303

5 files changed: +232 -20 lines changed
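At a glance, the commit adds an optional `standardize` argument to `ContinuousApproximator`, backed by a new `Standardization` layer that tracks running moments. A minimal usage sketch — the adapter and network choices below are illustrative placeholders, not part of this commit:

import bayesflow as bf

# hypothetical setup; any adapter/network combination works the same way
adapter = bf.approximators.ContinuousApproximator.build_adapter(
    inference_variables=["theta"],
    inference_conditions=["x"],
)

approximator = bf.approximators.ContinuousApproximator(
    adapter=adapter,
    inference_network=bf.networks.CouplingFlow(),
    standardize="all",  # default; also accepts a single key, a list of keys, or None
)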

bayesflow/approximators/approximator.py

Lines changed: 2 additions & 4 deletions
@@ -1,5 +1,3 @@
-from collections.abc import Mapping
-
 import multiprocessing as mp
 
 import keras
@@ -22,8 +20,8 @@ def build_adapter(cls, **kwargs) -> Adapter:
         # implemented by each respective architecture
         raise NotImplementedError
 
-    def build_from_data(self, data: Mapping[str, any]) -> None:
-        self.compute_metrics(**filter_kwargs(data, self.compute_metrics), stage="training")
+    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+        self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
         self.built = True
 
     @classmethod

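The renamed `adapted_data` parameter makes explicit that `build_from_data` receives the adapter's output. `filter_kwargs` then forwards only the entries that `compute_metrics` accepts; a minimal sketch of that behavior (an assumption about the helper, not the bayesflow implementation):

import inspect

def filter_kwargs(kwargs: dict, f) -> dict:
    # keep only keys that match parameters in f's signature
    accepted = set(inspect.signature(f).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}

def compute_metrics(inference_variables, inference_conditions=None, stage="training"):
    pass

adapted_data = {"inference_variables": 1, "inference_conditions": 2, "unused_key": 3}
print(filter_kwargs(adapted_data, compute_metrics))
# -> {'inference_variables': 1, 'inference_conditions': 2}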
bayesflow/approximators/continuous_approximator.py

Lines changed: 92 additions & 16 deletions
@@ -11,6 +11,7 @@
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 
 from .approximator import Approximator
+from ..networks.standardization import Standardization
 
 
 @serializable("bayesflow.approximators")
@@ -40,12 +41,17 @@ def __init__(
         adapter: Adapter,
         inference_network: InferenceNetwork,
         summary_network: SummaryNetwork = None,
+        standardize: str | Sequence[str] | None = "all",
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.adapter = adapter
         self.inference_network = inference_network
         self.summary_network = summary_network
+        self.standardize = standardize
+        self.inference_variables_norm = None
+        self.summary_variables_norm = None
+        self.inference_conditions_norm = None
 
     @classmethod
     def build_adapter(
@@ -112,6 +118,31 @@ def compile(
 
         return super().compile(*args, **kwargs)
 
+    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+        # Determine which inputs to standardize
+        if self.standardize == "all":
+            keys = ["inference_variables", "summary_variables", "inference_conditions"]
+        elif isinstance(self.standardize, str):
+            keys = [self.standardize]
+        elif isinstance(self.standardize, Sequence):
+            keys = self.standardize
+        else:
+            keys = []
+
+        if "inference_variables" in keys:
+            self.inference_variables_norm = Standardization()
+            self.inference_variables_norm(adapted_data["inference_variables"])
+        if "summary_variables" in keys and self.summary_network:
+            self.summary_variables_norm = Standardization()
+            self.summary_variables_norm(adapted_data["summary_variables"])
+        if "inference_conditions" in keys:
+            self.inference_conditions_norm = Standardization()
+            self.inference_conditions_norm(adapted_data["inference_conditions"])
+
+        # Call compute_metrics once to build the inner networks
+        self.compute_metrics(**filter_kwargs(adapted_data, self.compute_metrics), stage="training")
+        self.built = True
+
     def compile_from_config(self, config):
         self.compile(**deserialize(config))
         if hasattr(self, "optimizer") and self.built:
@@ -126,6 +157,10 @@ def compute_metrics(
         sample_weight: Tensor = None,
         stage: str = "training",
     ) -> dict[str, Tensor]:
+        # Standardize the optional inference conditions, if configured
+        if inference_conditions is not None and self.inference_conditions_norm:
+            inference_conditions = self.inference_conditions_norm(inference_conditions, stage=stage)
+
         if self.summary_network is None:
             if summary_variables is not None:
                 raise ValueError("Cannot compute summary metrics without a summary network.")
@@ -135,6 +170,9 @@
             if summary_variables is None:
                 raise ValueError("Summary variables are required when a summary network is present.")
 
+            if self.summary_variables_norm is not None:
+                summary_variables = self.summary_variables_norm(summary_variables, stage=stage)
+
             summary_metrics = self.summary_network.compute_metrics(summary_variables, stage=stage)
             summary_outputs = summary_metrics.pop("outputs")
 
@@ -146,6 +184,10 @@
 
         # Force a conversion to Tensor
         inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
+
+        if self.inference_variables_norm is not None:
+            inference_variables = self.inference_variables_norm(inference_variables, stage=stage)
+
         inference_metrics = self.inference_network.compute_metrics(
             inference_variables, conditions=inference_conditions, sample_weight=sample_weight, stage=stage
         )
@@ -223,6 +265,7 @@ def get_config(self):
             "adapter": self.adapter,
             "inference_network": self.inference_network,
             "summary_network": self.summary_network,
+            "standardize": self.standardize,
         }
 
         return base_config | serialize(config)
@@ -349,16 +392,33 @@ def sample(
         # Ensure only keys relevant for sampling are present in the conditions dictionary
         conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.SAMPLE_KEYS}
 
+        # Optionally standardize conditions
+        if "summary_variables" in conditions and self.summary_variables_norm:
+            conditions["summary_variables"] = self.summary_variables_norm(
+                conditions["summary_variables"], stage="inference"
+            )
+
+        if "inference_conditions" in conditions and self.inference_conditions_norm:
+            conditions["inference_conditions"] = self.inference_conditions_norm(
+                conditions["inference_conditions"], stage="inference"
+            )
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
-        conditions = {"inference_variables": self._sample(num_samples=num_samples, **conditions, **kwargs)}
-        conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
+
+        # Sample, then undo the optional standardization
+        samples = self._sample(num_samples=num_samples, **conditions, **kwargs)
+
+        if self.inference_variables_norm:
+            samples = self.inference_variables_norm(samples, stage="inference", forward=False)
+
+        samples = {"inference_variables": samples}
+        samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
 
         # Back-transform quantities and samples
-        conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
+        samples = self.adapter(samples, inverse=True, strict=False, **kwargs)
 
         if split:
-            conditions = split_arrays(conditions, axis=-1)
-        return conditions
+            samples = split_arrays(samples, axis=-1)
+        return samples
 
     def _sample(
         self,
@@ -400,37 +460,35 @@ def _sample(
             **filter_kwargs(kwargs, self.inference_network.sample),
         )
 
-    def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
+    def summaries(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         """
-        Computes the summaries of given data.
+        Computes the learned summary statistics of given inputs.
 
         The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
 
         Parameters
         ----------
         data : Mapping[str, np.ndarray]
-            Dictionary of data as NumPy arrays.
+            Dictionary of simulated or real quantities as NumPy arrays.
         **kwargs : dict
             Additional keyword arguments for the adapter and the summary network.
 
         Returns
         -------
         summaries : np.ndarray
-            Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
-
-        Raises
-        ------
-        ValueError
-            If the approximator does not have a summary network, or the adapter does not produce the output required
-            by the summary network.
+            The learned summary statistics.
         """
         if self.summary_network is None:
-            raise ValueError("A summary network is required to compute summeries.")
+            raise ValueError("A summary network is required to compute summaries.")
+
         data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
         if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
             raise ValueError("Summary variables are required to compute summaries.")
+
         summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
         summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
+        summaries = keras.ops.convert_to_numpy(summaries)
+
         return summaries
 
     def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
@@ -451,6 +509,24 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
             Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
         """
         data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs)
+
+        # Optionally standardize conditions and variables
+        if "summary_variables" in data and self.summary_variables_norm:
+            data["summary_variables"] = self.summary_variables_norm(data["summary_variables"], stage="inference")
+
+        if "inference_conditions" in data and self.inference_conditions_norm:
+            data["inference_conditions"] = self.inference_conditions_norm(
+                data["inference_conditions"], stage="inference"
+            )
+
+        if self.inference_variables_norm:
+            data["inference_variables"], log_det_jac = self.inference_variables_norm(
+                data["inference_variables"], stage="inference", log_det_jac=True
+            )
+            log_det_jac = keras.ops.convert_to_numpy(log_det_jac)
+        else:
+            log_det_jac = 0.0
+
         data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
         log_prob = self._log_prob(**data, **kwargs)
         log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob)
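For reference, how `build_from_data` above resolves the new `standardize` argument into the set of standardized inputs (a restatement of the branch logic, not additional API):

standardize = "all"  # -> inference_variables, summary_variables, inference_conditions (default)
standardize = "inference_variables"  # a single key -> just that input
standardize = ["summary_variables", "inference_conditions"]  # a sequence -> used as-is
standardize = None  # -> no standardization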
bayesflow/networks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .standardization import Standardization
bayesflow/networks/standardization.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+from collections.abc import Sequence
+
+import keras
+
+from bayesflow.types import Tensor, Shape
+from bayesflow.utils.serialization import serialize, serializable
+from bayesflow.utils import expand_left_as
+
+
+@serializable("bayesflow.networks")
+class Standardization(keras.Layer):
+    def __init__(self, momentum: float = 0.99):
+        """
+        Initializes a Standardization layer that keeps track of the running mean and
+        running standard deviation across batches of tensors.
+
+        Parameters
+        ----------
+        momentum : float, optional
+            Momentum for the exponential moving average used to update the mean and
+            standard deviation during training. Must be between 0 and 1.
+            Default is 0.99.
+        """
+        super().__init__()
+
+        self.momentum = momentum
+        self.moving_mean = None
+        self.moving_std = None
+
+    def build(self, input_shape: Shape, **kwargs):
+        self.moving_mean = self.add_weight(shape=(input_shape[-1],), initializer="zeros", name="moving_mean", trainable=False)
+        self.moving_std = self.add_weight(shape=(input_shape[-1],), initializer="ones", name="moving_std", trainable=False)
+
+    def get_config(self) -> dict:
+        config = {"momentum": self.momentum}
+        return serialize(config)
+
+    def _update_moments(self, x: Tensor):
+        # Compute batch moments over all axes except the last (feature) axis
+        reduce_axes = list(range(keras.ops.ndim(x) - 1))
+        mean = keras.ops.mean(x, axis=reduce_axes)
+        std = keras.ops.std(x, axis=reduce_axes)
+        self.moving_mean.assign(self.momentum * self.moving_mean + (1.0 - self.momentum) * mean)
+        self.moving_std.assign(self.momentum * self.moving_std + (1.0 - self.momentum) * std)
+
+    def call(
+        self, x: Tensor, stage: str = "inference", forward: bool = True, log_det_jac: bool = False, **kwargs
+    ) -> Tensor | Sequence[Tensor]:
+        """
+        Applies standardization or its inverse to the input tensor, optionally computing
+        the log-determinant of the Jacobian.
+
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor of shape (..., dim).
+        stage : str, optional
+            Indicates the stage of computation. If "training", the running statistics
+            are updated. Default is "inference".
+        forward : bool, optional
+            If True, applies the standardization (x - mean) / std.
+            If False, applies the inverse transformation x * std + mean. Default is True.
+        log_det_jac : bool, optional
+            Whether to also return the log-determinant of the Jacobian of the applied
+            transformation. Default is False.
+
+        Returns
+        -------
+        Tensor or Sequence[Tensor]
+            The standardized tensor if `forward` is True, otherwise the un-standardized tensor.
+            If `log_det_jac` is True, a tuple (transformed tensor, log-determinant) is returned.
+        """
+        if stage == "training":
+            self._update_moments(x)
+
+        if forward:
+            x = (x - expand_left_as(self.moving_mean, x)) / expand_left_as(self.moving_std, x)
+        else:
+            x = expand_left_as(self.moving_mean, x) + expand_left_as(self.moving_std, x) * x
+
+        if log_det_jac:
+            ldj = keras.ops.sum(keras.ops.log(keras.ops.abs(self.moving_std)), axis=-1)
+            ldj = keras.ops.broadcast_to(ldj, keras.ops.shape(x)[:-1])
+            if forward:
+                ldj = -ldj
+            return x, ldj
+
+        return x
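A quick numeric check of the log-determinant logic above: standardization is a per-feature affine map z = (x - mean) / std, so log|det J| = -sum_i log|std_i| in the forward direction, and the inverse contributes the same magnitude with opposite sign. A sketch with NumPy stand-ins for the moving statistics:

import numpy as np

std = np.array([2.0, 0.5, 4.0])
ldj_forward = -np.sum(np.log(np.abs(std)))  # log|det diag(1/std)|
ldj_inverse = np.sum(np.log(np.abs(std)))   # log|det diag(std)|
print(ldj_forward, ldj_inverse)             # -1.386..., 1.386...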
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+import numpy as np
+import keras
+
+from bayesflow.networks.standardization import Standardization
+
+
+def test_forward_standardization_training():
+    random_input = keras.random.normal((8, 4))
+
+    layer = Standardization(momentum=0.0)  # no EMA for test stability
+    layer.build(random_input.shape)
+
+    out = layer(random_input, stage="training", forward=True)
+
+    moving_mean = keras.ops.convert_to_numpy(layer.moving_mean)
+    moving_std = keras.ops.convert_to_numpy(layer.moving_std)
+    random_input = keras.ops.convert_to_numpy(random_input)
+    out = keras.ops.convert_to_numpy(out)
+
+    # the moments should now match the batch statistics
+    np.testing.assert_allclose(moving_mean, np.mean(random_input, axis=0), atol=1e-5)
+    np.testing.assert_allclose(moving_std, np.std(random_input, axis=0), atol=1e-5)
+
+    assert out.shape == random_input.shape
+    assert not np.any(np.isnan(out))
+
+
+def test_inverse_standardization_ldj():
+    random_input = keras.random.normal((1, 3))
+
+    layer = Standardization(momentum=0.0)
+    layer.build(random_input.shape)
+
+    _ = layer(random_input, stage="training", forward=True)  # trigger moment update
+    inv_x, ldj = layer(random_input, stage="inference", forward=False, log_det_jac=True)
+
+    assert inv_x.shape == random_input.shape
+    assert ldj.shape == random_input.shape[:-1]
+
+
+def test_consistency_forward_inverse():
+    random_input = keras.random.normal((4, 20, 5))
+    layer = Standardization(momentum=0.0)
+    layer.build((5,))
+    standardized = layer(random_input, stage="training", forward=True)
+    recovered = layer(standardized, stage="inference", forward=False)
+
+    random_input = keras.ops.convert_to_numpy(random_input)
+    recovered = keras.ops.convert_to_numpy(recovered)
+
+    np.testing.assert_allclose(random_input, recovered, atol=1e-4)
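Why `momentum=0.0` makes these tests exact: `_update_moments` computes `new = momentum * old + (1 - momentum) * batch`, so a zero momentum replaces the running moments with the current batch statistics outright, and the assertions can compare directly against `np.mean`/`np.std`. A one-line check:

momentum, old_mean, batch_mean = 0.0, 1.0, 3.7
new_mean = momentum * old_mean + (1 - momentum) * batch_mean
assert new_mean == batch_mean  # running moment equals the batch statistic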
