Commit d79b17a

Fix init bugs, adapt notebooks
1 parent ceab303 commit d79b17a

File tree (4 files changed, +368 -149 lines):

bayesflow/approximators/continuous_approximator.py
bayesflow/networks/standardization/standardization.py
examples/Linear_Regression_Starter.ipynb
examples/SIR_Posterior_Estimation.ipynb

bayesflow/approximators/continuous_approximator.py

Lines changed: 14 additions & 22 deletions
@@ -49,6 +49,7 @@ def __init__(
         self.inference_network = inference_network
         self.summary_network = summary_network
         self.standardize = standardize
+
         self.inference_variables_norm = None
         self.summary_variables_norm = None
         self.inference_conditions_norm = None
@@ -59,7 +60,6 @@ def build_adapter(
         inference_variables: Sequence[str],
         inference_conditions: Sequence[str] = None,
         summary_variables: Sequence[str] = None,
-        standardize: bool = True,
         sample_weight: str = None,
     ) -> Adapter:
         """Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
@@ -72,8 +72,6 @@ def build_adapter(
             Names of the inference conditions in the data
         summary_variables : Sequence of str, optional
             Names of the summary variables in the data
-        standardize : bool, optional
-            Decide whether to standardize all variables, default is True
         sample_weight : str, optional
             Name of the sample weights
         """
@@ -95,9 +93,6 @@ def build_adapter(

         adapter.keep(["inference_variables", "inference_conditions", "summary_variables", "sample_weight"])

-        if standardize:
-            adapter.standardize(exclude="sample_weight")
-
         return adapter

     def compile(
@@ -118,7 +113,7 @@ def compile(

         return super().compile(*args, **kwargs)

-    def build_from_data(self, adapted_data: dict[str, any]) -> None:
+    def build_from_data(self, adapted_data: dict[str, any]):
         # Determine input standardization
         if self.standardize == "all":
             keys = ["inference_variables", "summary_variables", "inference_conditions"]
@@ -129,13 +124,15 @@ def build_from_data(self, adapted_data: dict[str, any]) -> None:
         else:
             keys = []

-        if "inference_variables" in keys:
+        if "inference_variables" in adapted_data and "inference_variables" in keys:
             self.inference_variables_norm = Standardization()
             self.inference_variables_norm(adapted_data["inference_variables"])
-        if "summary_variables" in keys and self.summary_network:
+
+        if "summary_variables" in adapted_data and "summary_variables" in keys and self.summary_network:
             self.summary_variables_norm = Standardization()
             self.summary_variables_norm(adapted_data["summary_variables"])
-        if "inference_conditions" in keys:
+
+        if "inference_conditions" in adapted_data and "inference_conditions" in keys:
             self.inference_conditions_norm = Standardization()
             self.inference_conditions_norm(adapted_data["inference_conditions"])
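
The rewritten build_from_data only creates a Standardization layer for a variable group that is both present in `adapted_data` and selected via `self.standardize`. A minimal sketch of that key selection; the branch between `"all"` and the final `else` falls outside the diff context, so the handling of a list-valued `standardize` below is an assumption, and `standardize_keys` is a hypothetical helper name:

    # Hypothetical helper mirroring the key selection in build_from_data.
    # The elif branch is hidden in the diff, so the list/tuple case is assumed.
    def standardize_keys(standardize) -> list[str]:
        all_keys = ["inference_variables", "summary_variables", "inference_conditions"]
        if standardize == "all":
            return all_keys
        if isinstance(standardize, (list, tuple)):  # assumed branch
            return [key for key in all_keys if key in standardize]
        return []
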
@@ -394,21 +391,18 @@ def sample(

         # Optionally standardize conditions
         if "summary_variables" in conditions and self.summary_variables_norm:
-            conditions["summary_variables"] = self.summary_variables_norm(
-                conditions["summary_variables"], stage="inference"
-            )
+            conditions["summary_variables"] = self.summary_variables_norm(conditions["summary_variables"])

         if "inference_conditions" in conditions and self.inference_conditions_norm:
-            conditions["inference_conditions"] = self.inference_conditions_norm(
-                conditions["inference_conditions"], stage="inference"
-            )
+            conditions["inference_conditions"] = self.inference_conditions_norm(conditions["inference_conditions"])
+
         conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

         # Sample and undo optional standardization
         samples = self._sample(num_samples=num_samples, **conditions, **kwargs)

         if self.inference_variables_norm:
-            samples = self.inference_variables_norm(samples, stage="inference", forward=False)
+            samples = self.inference_variables_norm(samples, forward=False)

         samples = {"inference_variables": samples}
         samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
@@ -512,16 +506,14 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict

         # Optionally standardize conditions and variables
         if "summary_variables" in data and self.summary_variables_norm:
-            data["summary_variables"] = self.summary_variables_norm(data["summary_variables"], stage="inference")
+            data["summary_variables"] = self.summary_variables_norm(data["summary_variables"])

         if "inference_conditions" in data and self.inference_conditions_norm:
-            data["inference_conditions"] = self.inference_conditions_norm(
-                data["inference_conditions"], stage="inference"
-            )
+            data["inference_conditions"] = self.inference_conditions_norm(data["inference_conditions"])

         if self.inference_variables_norm:
             data["inference_variables"], log_det_jac = self.summary_variables_norm(
-                data["inference_variables"], stage="inference", log_det_jac=True
+                data["inference_variables"], log_det_jac=True
             )
             log_det_jac = keras.ops.convert_to_numpy(log_det_jac)
         else:
bayesflow/networks/standardization/standardization.py

Lines changed: 22 additions & 13 deletions
@@ -3,13 +3,13 @@
 import keras

 from bayesflow.types import Tensor, Shape
-from bayesflow.utils.serialization import serialize, serializable
+from bayesflow.utils.serialization import serialize, deserialize, serializable
 from bayesflow.utils import expand_left_as


 @serializable("bayesflow.networks")
 class Standardization(keras.Layer):
-    def __init__(self, momentum: float = 0.99):
+    def __init__(self, momentum: float = 0.95, epsilon: float = 1e-6):
         """
         Initializes a Standardization layer that will keep track of the running mean and
         running standard deviation across a batch of tensors.
@@ -19,27 +19,28 @@ def __init__(self, momentum: float = 0.99):
         momentum : float, optional
             Momentum for the exponential moving average used to update the mean and
             standard deviation during training. Must be between 0 and 1.
-            Default is 0.99.
+            Default is 0.95.
+        epsilon : float, optional
+            Stability parameter to avoid division by zero.
         """
         super().__init__()

         self.momentum = momentum
+        self.epsilon = epsilon
         self.moving_mean = None
         self.moving_std = None

-    def build(self, input_shape: Shape, **kwargs):
-        self.moving_mean = self.add_weight(shape=(input_shape[-1],), initializer="ones", name="scale", trainable=False)
-        self.moving_std = self.add_weight(shape=(input_shape[-1],), initializer="zeros", name="bias", trainable=False)
+    def build(self, input_shape: Shape):
+        self.moving_mean = self.add_weight(shape=(input_shape[-1],), initializer="zeros", trainable=False)
+        self.moving_std = self.add_weight(shape=(input_shape[-1],), initializer="ones", trainable=False)

     def get_config(self) -> dict:
-        config = {"momentum": self.momentum}
+        config = {"momentum": self.momentum, "epsilon": self.epsilon}
         return serialize(config)

-    def _update_moments(self, x: Tensor):
-        mean = keras.ops.mean(x, axis=list(range(keras.ops.ndim(x)))[:-1])
-        std = keras.ops.std(x, axis=list(range(keras.ops.ndim(x)))[:-1])
-        self.moving_mean.assign(self.momentum * self.moving_mean + (1.0 - self.momentum) * mean)
-        self.moving_std.assign(self.momentum * self.moving_std + (1.0 - self.momentum) * std)
+    @classmethod
+    def from_config(cls, config, custom_objects=None):
+        return cls(**deserialize(config, custom_objects=custom_objects))

     def call(
         self, x: Tensor, stage: str = "inference", forward: bool = True, log_det_jac: bool = False, **kwargs
@@ -53,7 +54,7 @@ def call(
             Input tensor of shape (..., dim).
         stage : str, optional
             Indicates the stage of computation. If "training", the running statistics
-            are updated. Default is "training".
+            are updated. Default is "inference".
         forward : bool, optional
             If True, apply standardization: (x - mean) / std.
             If False, apply inverse transformation: x * std + mean and return the log-determinant
@@ -84,3 +85,11 @@ def call(
             return x, ldj

         return x
+
+    def _update_moments(self, x: Tensor):
+        mean = keras.ops.mean(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
+        std = keras.ops.std(x, axis=tuple(range(keras.ops.ndim(x) - 1)))
+        std = keras.ops.maximum(std, self.epsilon)
+
+        self.moving_mean.assign(self.momentum * self.moving_mean + (1.0 - self.momentum) * mean)
+        self.moving_std.assign(self.momentum * self.moving_std + (1.0 - self.momentum) * std)
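
The relocated _update_moments now reduces over every axis except the trailing feature axis, clamps the batch standard deviation at `epsilon`, and blends the result into the running statistics with the new default momentum of 0.95. A self-contained sketch of that update rule using keras.ops directly (shapes and data are arbitrary examples):

    import numpy as np
    import keras

    momentum, epsilon = 0.95, 1e-6
    x = keras.ops.convert_to_tensor(np.random.randn(32, 10, 4).astype("float32"))

    # Batch moments over all axes except the last (feature) axis.
    axes = tuple(range(keras.ops.ndim(x) - 1))
    batch_mean = keras.ops.mean(x, axis=axes)
    batch_std = keras.ops.maximum(keras.ops.std(x, axis=axes), epsilon)

    # Exponential moving average, starting from the corrected build()
    # initializers: mean from zeros, std from ones.
    moving_mean = keras.ops.zeros((4,))
    moving_std = keras.ops.ones((4,))
    moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    moving_std = momentum * moving_std + (1.0 - momentum) * batch_std

    z = (x - moving_mean) / moving_std     # forward standardization in call()
    x_back = z * moving_std + moving_mean  # inverse transform (forward=False)

This also makes the "init bugs" of the commit title concrete: the old build() initialized the moving mean with "ones" and the moving std with "zeros", swapping the two roles and inviting division by zero before any epsilon clamping.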

examples/Linear_Regression_Starter.ipynb

Lines changed: 3 additions & 6 deletions
@@ -382,7 +382,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {
     "ExecuteTime": {
      "end_time": "2025-02-14T10:51:30.684080Z",
@@ -401,7 +401,6 @@
    "    .broadcast(\"N\", to=\"x\")\n",
    "    .as_set([\"x\", \"y\"])\n",
    "    .constrain(\"sigma\", lower=0)\n",
-    "    .standardize(exclude=[\"N\"])\n",
    "    .sqrt(\"N\")\n",
    "    .convert_dtype(\"float64\", \"float32\")\n",
    "    .concatenate([\"beta\", \"sigma\"], into=\"inference_variables\")\n",
@@ -424,9 +423,6 @@
    "\n",
    "The `.constrain(\"sigma\", lower=0)` transform ensures that the residual standard deviation parameter `sigma` will always be positive. Without this constrain, the neural networks may attempt to predict negative `sigma` which of course would not make much sense.\n",
    "\n",
-    "Standardidazation via `.standardize()` is important for neural networks to learn\n",
-    "reliably without, for example, exploding or vanishing gradients during training. However, we need to exclude the variable `N` from standardization, via `standardize(exclude=[\"N\"])`. This is because `N` is a constant within each batch of training data and can hence not be standardized. In the future, bayesflow will automatically detect this case so that we don't have to manually exclude such constant variables from standardization.\n",
-    "\n",
    "Let's check the shape of our processed data to be passed to the neural networks:"
   ]
  },
@@ -1028,7 +1024,8 @@
    "name": "python3"
   },
   "language_info": {
-   "name": "python"
+   "name": "python",
+   "version": "3.11.12"
   },
   "widgets": {
    "application/vnd.jupyter.widget-state+json": {

examples/SIR_Posterior_Estimation.ipynb

Lines changed: 329 additions & 108 deletions (large diff, not rendered)

0 commit comments