
Commit dd24941

Merge branch 'standardize_in_approx' of https://github.com/bayesflow-org/bayesflow into standardize_in_approx

2 parents: 5773d28 + caf0491
File tree

4 files changed: +130 -43 lines


bayesflow/networks/standardization/standardization.py

Lines changed: 61 additions & 32 deletions
@@ -3,51 +3,49 @@
 import keras
 
 from bayesflow.types import Tensor, Shape
-from bayesflow.utils.serialization import serialize, deserialize, serializable
+from bayesflow.utils.serialization import serializable
 from bayesflow.utils import expand_left_as, layer_kwargs
 from bayesflow.utils.tree import flatten_shape
 
 
 @serializable("bayesflow.networks")
 class Standardization(keras.Layer):
-    def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initializes a Standardization layer that will keep track of the running mean and
         running standard deviation across a batch of potentially nested tensors.
 
+        The layer computes and stores running estimates of the mean and variance using a numerically
+        stable online algorithm, allowing for consistent normalization during both training and inference,
+        regardless of batch composition.
+
         Parameters
         ----------
-        momentum : float, optional
-            Momentum for the exponential moving average used to update the mean and
-            standard deviation during training. Must be between 0 and 1.
-            Default is 0.95.
-        epsilon: float, optional
-            Stability parameter to avoid division by zero.
+        **kwargs
+            Additional keyword arguments passed to the base Keras Layer.
+
+        Notes
+        -----
         """
         super().__init__(**layer_kwargs(kwargs))
 
-        self.momentum = momentum
-        self.epsilon = epsilon
         self.moving_mean = None
-        self.moving_std = None
+        self.moving_m2 = None
+        self.count = None
+
+    def moving_std(self, index: int) -> Tensor:
+        return keras.ops.sqrt(self.moving_m2[index] / self.count)
 
     def build(self, input_shape: Shape):
         flattened_shapes = flatten_shape(input_shape)
+
         self.moving_mean = [
             self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-        self.moving_std = [
-            self.add_weight(shape=(shape[-1],), initializer="ones", trainable=False) for shape in flattened_shapes
+        self.moving_m2 = [
+            self.add_weight(shape=(shape[-1],), initializer="zeros", trainable=False) for shape in flattened_shapes
         ]
-
-    def get_config(self) -> dict:
-        base_config = super().get_config()
-        config = {"momentum": self.momentum, "epsilon": self.epsilon}
-        return base_config | serialize(config)
-
-    @classmethod
-    def from_config(cls, config, custom_objects=None):
-        return cls(**deserialize(config, custom_objects=custom_objects))
+        self.count = self.add_weight(shape=(), initializer="zeros", trainable=False)
 
     def call(
         self,
@@ -80,23 +78,25 @@ def call(
         flattened = keras.tree.flatten(x)
         outputs, log_det_jacs = [], []
 
-        for i, val in enumerate(flattened):
+        for idx, val in enumerate(flattened):
            if stage == "training":
-                self._update_moments(val, i)
+                self._update_moments(val, idx)
 
-            mean = expand_left_as(self.moving_mean[i], val)
-            std = expand_left_as(self.moving_std[i], val)
+            mean = expand_left_as(self.moving_mean[idx], val)
+            std = expand_left_as(self.moving_std(idx), val)
 
            if forward:
                out = (val - mean) / std
+                # if the std is zero, out will become nan. As val - mean(val) = 0 if std(val) = 0,
+                # we can just replace them with zeros.
+                out = keras.ops.nan_to_num(out, nan=0.0)
            else:
                out = mean + std * val
 
            outputs.append(out)
 
            if log_det_jac:
                ldj = keras.ops.sum(keras.ops.log(keras.ops.abs(std)), axis=-1)
-                # For convenience, tile to batch shape of val
                ldj = keras.ops.tile(ldj, keras.ops.shape(val)[:-1])
                log_det_jacs.append(-ldj if forward else ldj)
 
@@ -108,9 +108,38 @@ def call(
        return outputs
 
    def _update_moments(self, x: Tensor, index: int):
-        mean = keras.ops.mean(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
-        std = keras.ops.std(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
-        std = keras.ops.maximum(std, self.epsilon)
+        """
+        Incrementally updates the running mean and variance (M2) per feature using a numerically
+        stable online algorithm.
+
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor of shape (..., features), where all axes except the last are treated as batch/sample axes.
+            The method computes batch-wise statistics by aggregating over all non-feature axes and updates the
+            running totals (mean, M2, and sample count) accordingly.
+        index : int
+            The index of the corresponding running statistics to be updated.
+        """
+
+        reduce_axes = tuple(range(x.ndim - 1))
+        batch_count = keras.ops.cast(keras.ops.shape(x)[0], self.count.dtype)
+
+        # Compute batch mean and M2 per feature
+        batch_mean = keras.ops.mean(x, axis=reduce_axes)
+        batch_m2 = keras.ops.sum((x - expand_left_as(batch_mean, x)) ** 2, axis=reduce_axes)
+
+        # Read current totals
+        mean = self.moving_mean[index]
+        m2 = self.moving_m2[index]
+        count = self.count
+
+        total_count = count + batch_count
+        delta = batch_mean - mean
+
+        new_mean = mean + delta * (batch_count / total_count)
+        new_m2 = m2 + batch_m2 + (delta**2) * (count * batch_count / total_count)
 
-        self.moving_mean[index].assign(self.momentum * self.moving_mean[index] + (1.0 - self.momentum) * mean)
-        self.moving_std[index].assign(self.momentum * self.moving_std[index] + (1.0 - self.momentum) * std)
+        self.moving_mean[index].assign(new_mean)
+        self.moving_m2[index].assign(new_m2)
+        self.count.assign(total_count)
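
The update above is the classic parallel combination of (count, mean, M2) statistics: merging the stored summary with a fresh batch is exact, so streaming over batches reproduces the pooled moments. Below is a minimal standalone NumPy sketch of the same update, written under the assumption of 2-D inputs of shape (batch, features); the helper name merge_moments is illustrative and not part of bayesflow.

import numpy as np


def merge_moments(count, mean, m2, batch):
    # Merge running (count, mean, M2) with a new batch of shape (n, features).
    # M2 accumulates summed squared deviations, so variance = M2 / count.
    batch_count = batch.shape[0]
    batch_mean = batch.mean(axis=0)
    batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)

    total_count = count + batch_count
    delta = batch_mean - mean

    new_mean = mean + delta * (batch_count / total_count)
    new_m2 = m2 + batch_m2 + delta**2 * (count * batch_count / total_count)
    return total_count, new_mean, new_m2


# Streaming over two batches reproduces the pooled statistics exactly.
rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(8, 4)), rng.normal(size=(16, 4))

count, mean, m2 = 0, np.zeros(4), np.zeros(4)
for batch in (x1, x2):
    count, mean, m2 = merge_moments(count, mean, m2, batch)

pooled = np.concatenate([x1, x2], axis=0)
assert np.allclose(mean, pooled.mean(axis=0))
assert np.allclose(np.sqrt(m2 / count), pooled.std(axis=0))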

bayesflow/scores/multivariate_normal_score.py

Lines changed: 11 additions & 0 deletions
@@ -26,6 +26,17 @@ class MultivariateNormalScore(ParametricDistributionScore):
     For more information see :py:class:`ScoringRule`.
     """
 
+    RANK: dict[str, int] = {"covariance": 2}
+    """
+    The covariance matrix is a rank-2 tensor, and as such the inverse of the standardization operation is
+
+        x = x' * sigma ^ 2
+
+    Accordingly, covariance is also included in :py:attr:`NO_SHIFT`.
+    """
+
+    NO_SHIFT: tuple[str] = ("covariance",)
+
     def __init__(self, dim: int = None, links: dict = None, **kwargs):
         super().__init__(links=links, **kwargs)
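
As a sanity check on the rank-2 convention: standardizing data divides covariance entry (i, j) by sigma_i * sigma_j, so undoing standardization multiplies by sigma_i * sigma_j (sigma squared on the diagonal) and applies no mean shift. A small NumPy sketch under that assumption follows; the helper name is illustrative and not the library API.

import numpy as np


def inverse_standardize_covariance(cov_std, sigma):
    # Covariance estimated in standardized space, mapped back to the data scale:
    # entry (i, j) is scaled by sigma_i * sigma_j, with no mean shift.
    return cov_std * np.outer(sigma, sigma)


rng = np.random.default_rng(1)
x = rng.normal(size=(1000, 3)) * np.array([1.0, 5.0, 0.1])
mean, sigma = x.mean(axis=0), x.std(axis=0)

cov_in_standardized_space = np.cov((x - mean) / sigma, rowvar=False)
recovered = inverse_standardize_covariance(cov_in_standardized_space, sigma)
assert np.allclose(recovered, np.cov(x, rowvar=False))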

bayesflow/scores/scoring_rule.py

Lines changed: 29 additions & 3 deletions
@@ -26,10 +26,9 @@ class ScoringRule:
     and covariance simultaneously.
     """
 
-    NOT_TRANSFORMING_LIKE_VECTOR_WARNING = tuple()
+    NOT_TRANSFORMING_LIKE_VECTOR_WARNING: tuple[str] = tuple()
     """
-    This variable contains names of prediction heads that should lead to a warning when the adapter is applied
-    in inverse direction to them.
+    Names of prediction heads for which to warn if the adapter is called on their estimates in inverse direction.
 
     Prediction heads can output estimates in spaces other than the target distribution space.
     To such estimates the adapter cannot be straightforwardly applied in inverse direction,
@@ -38,6 +37,33 @@ class ScoringRule:
     with a type of estimate whenever the adapter is applied to them in inverse direction.
     """
 
+    RANK: dict[str, int] = {}
+    """
+    Mapping of prediction head names to their tensor rank for inverse standardization.
+
+    The rank indicates the power to which the standard deviation is raised before being multiplied to some estimate
+    in standardized space.
+
+        x = x' * sigma ^ rank [ + mean ]
+
+    If a head is not present in this mapping, a default rank of 1 is assumed.
+
+    Typically, if :py:attr:`RANK` is modified for an estimate, it is also included in :py:attr:`NO_SHIFT`.
+    """
+
+    NO_SHIFT: tuple[str] = tuple()
+    """
+    Names of prediction heads whose estimates should not be shifted when applying inverse standardization.
+
+    During inverse standardization, point estimates are typically shifted by the stored mean vector. Any head
+    listed in this tuple will skip the shift step and only be scaled. By default, this tuple is empty,
+    meaning all heads will be shifted to undo standardization.
+
+        x = x' * sigma ^ rank + mean
+
+    See also :py:attr:`RANK`.
+    """
+
     def __init__(
         self,
         subnets: dict[str, str | type] = None,
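
Taken together, RANK and NO_SHIFT describe how inverse standardization treats each prediction head: scale by sigma raised to the head's rank, then shift by the mean unless the head opts out. One possible way to realize this dispatch is sketched below in NumPy, applying sigma once along each of the last rank axes; the helper is hypothetical and not the bayesflow implementation.

import numpy as np


def inverse_standardize(estimates, mean, std, rank=None, no_shift=()):
    # estimates: dict of head name -> array in standardized space
    # mean, std: per-feature statistics of the original data, shape (d,)
    # rank:      head name -> power of sigma (default 1), see RANK
    # no_shift:  heads that are only scaled, never shifted, see NO_SHIFT
    rank = rank or {}
    out = {}
    for head, value in estimates.items():
        r = rank.get(head, 1)
        scaled = value
        # Apply sigma once per feature axis (the last r axes), so a rank-2 head
        # such as a covariance matrix is scaled by sigma_i * sigma_j overall.
        for axis in range(-r, 0):
            shape = [1] * value.ndim
            shape[axis] = std.shape[0]
            scaled = scaled * std.reshape(shape)
        out[head] = scaled if head in no_shift else scaled + mean
    return out


# Usage with the MultivariateNormalScore conventions: the mean head is shifted
# and scaled once, the covariance head is scaled twice and not shifted.
mean, std = np.array([2.0, -1.0]), np.array([3.0, 0.5])
estimates = {"mean": np.zeros(2), "covariance": np.eye(2)}
restored = inverse_standardize(estimates, mean, std, rank={"covariance": 2}, no_shift=("covariance",))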

tests/test_networks/test_standardization.py

Lines changed: 29 additions & 8 deletions
@@ -10,21 +10,40 @@
 def test_forward_standardization_training():
     random_input = keras.random.normal((8, 4))
 
-    layer = Standardization(momentum=0.0)  # no EMA for test stability
+    layer = Standardization()
     layer.build(random_input.shape)
 
     out = layer(random_input, stage="training")
 
     moving_mean = keras.ops.convert_to_numpy(layer.moving_mean[0])
-    moving_std = keras.ops.convert_to_numpy(layer.moving_std[0])
     random_input = keras.ops.convert_to_numpy(random_input)
     out = keras.ops.convert_to_numpy(out)
 
     np.testing.assert_allclose(moving_mean, np.mean(random_input, axis=0), atol=1e-5)
-    np.testing.assert_allclose(moving_std, np.std(random_input, axis=0), atol=1e-5)
 
     assert out.shape == random_input.shape
     assert not np.any(np.isnan(out))
+    np.testing.assert_allclose(np.std(out, axis=0), 1.0, atol=1e-5)
+
+
+def test_forward_standardization_training_constant_batch():
+    constant_input = keras.ops.ones((8, 4))
+
+    layer = Standardization()
+    layer.build(constant_input.shape)
+
+    out = layer(constant_input, stage="training")
+
+    moving_mean = keras.ops.convert_to_numpy(layer.moving_mean[0])
+    constant_input = keras.ops.convert_to_numpy(constant_input)
+    out = keras.ops.convert_to_numpy(out)
+
+    np.testing.assert_allclose(moving_mean, np.mean(constant_input, axis=0), atol=1e-5)
+
+    assert out.shape == constant_input.shape
+    assert not np.any(np.isnan(out))
+    np.testing.assert_allclose(out, 0.0, atol=1e-5)
+    np.testing.assert_allclose(np.std(out, axis=0), 0.0, atol=1e-5)
 
 
 def test_inverse_standardization_ldj():
@@ -42,9 +61,10 @@ def test_inverse_standardization_ldj():
 
 def test_consistency_forward_inverse():
     random_input = keras.random.normal((4, 20, 5))
-    layer = Standardization(momentum=0.0)
-    layer.build((5,))
-    standardized = layer(random_input, stage="training", forward=True)
+    layer = Standardization()
+    _ = layer(random_input, stage="training", forward=True)
+
+    standardized = layer(random_input, stage="inference", forward=True)
     recovered = layer(standardized, stage="inference", forward=False)
 
     random_input = keras.ops.convert_to_numpy(random_input)
@@ -58,9 +78,10 @@ def test_nested_consistency_forward_inverse():
     random_input_b = keras.random.normal((4, 3))
     random_input = {"a": random_input_a, "b": random_input_b}
 
-    layer = Standardization(momentum=0.0)
+    layer = Standardization()
 
-    standardized = layer(random_input, stage="training", forward=True)
+    _ = layer(random_input, stage="training", forward=True)
+    standardized = layer(random_input, stage="inference", forward=True)
     recovered = layer(standardized, stage="inference", forward=False)
 
     random_input = keras.tree.map_structure(keras.ops.convert_to_numpy, random_input)
