
Commit 5773d28

Merge branch 'standardize_in_approx' of https://github.com/bayesflow-org/bayesflow into standardize_in_approx

2 parents 0c24db2 + b2bfeea

File tree: 10 files changed, +152 -54 lines

README.md
Lines changed: 4 additions & 3 deletions

@@ -51,10 +51,10 @@ complex to be described analytically.
 
 ## Install
 
-You can install the latest stable version from PyPI using:
+We currently support Python 3.10 to 3.12. You can install the latest stable version from PyPI using:
 
 ```bash
-pip install bayesflow
+pip install "bayesflow>=2.0"
 ```
 
 If you want the latest features, you can install from source:

@@ -134,7 +134,8 @@ Many examples from [Bayesian Cognitive Modeling: A Practical Course](https://bay
 5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
 7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)
-8. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
+8. [Likelihood estimation](examples/Likelihood_Estimation.ipynb)
+9. [Moving from BayesFlow v1.1 to v2.0](examples/From_BayesFlow_1.1_to_2.0.ipynb)
 
 More tutorials are always welcome! Please consider making a pull request if you have a cool application that you want to contribute.

bayesflow/approximators/continuous_approximator.py
Lines changed: 53 additions & 26 deletions

@@ -409,18 +409,10 @@ def sample(
         dict[str, np.ndarray]
             Dictionary containing generated samples with the same keys as `conditions`.
         """
-
-        # Apply adapter transforms to raw simulated / real quantities
-        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-
-        # Ensure only keys relevant for sampling are present in the conditions dictionary
+        # Adapt, optionally standardize and convert conditions to tensor.
+        conditions = self._prepare_data(conditions, **kwargs)
+        # Remove any superfluous keys, just retain actual conditions.  # TODO: is this necessary?
         conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
-        conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
-
-        # Optionally standardize conditions
-        for key in ContinuousApproximator.CONDITION_KEYS:
-            if key in conditions and key in self.standardize:
-                conditions[key] = self.standardize_layers[key](conditions[key])
 
         # Sample and undo optional standardization
         samples = self._sample(num_samples=num_samples, **conditions, **kwargs)

@@ -438,6 +430,51 @@ def sample(
         samples = split_arrays(samples, axis=-1)
         return samples
 
+    def _prepare_data(
+        self, data: Mapping[str, np.ndarray], log_det_jac: bool = False, **kwargs
+    ) -> dict[str, Tensor] | tuple[dict[str, Tensor], dict[str, Tensor]]:
+        """
+        Adapts, optionally standardizes, and converts the data to tensors to prepare it for the inference network.
+
+        Deals with data that represents only conditions, only inference_variables, or both.
+        """
+        # TODO:
+        # * [ ] better docstring
+
+        # Adapt, and optionally keep track of the log-det-Jacobian of transformations to inference_variables.
+        adapted = self.adapter(data, strict=False, stage="inference", log_det_jac=log_det_jac, **kwargs)
+        if log_det_jac:
+            data, log_det_jac_adapter = adapted
+            log_det_jac_inference_variables = log_det_jac_adapter.get("inference_variables", 0.0)
+        else:
+            data = adapted
+
+        # Optionally standardize conditions, if they are part of data.
+        conditions = {k: v for k, v in data.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        for key, value in conditions.items():
+            if key in self.standardize:
+                data[key] = self.standardize_layers[key](value)
+
+        # Optionally standardize inference variables, if they are part of data.
+        if "inference_variables" in data and "inference_variables" in self.standardize:
+            standardized = self.standardize_layers["inference_variables"](
+                data["inference_variables"], log_det_jac=log_det_jac
+            )
+
+            # Optionally keep track of the appropriate log_det_jac.
+            if log_det_jac:
+                data["inference_variables"], log_det_std = standardized
+                log_det_jac_inference_variables += keras.ops.convert_to_numpy(log_det_std)
+            else:
+                data["inference_variables"] = standardized
+
+        # Convert to tensor and return.
+        data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
+        if log_det_jac:
+            return data, log_det_jac_inference_variables
+        else:
+            return data
+
     def _sample(
         self,
         num_samples: int,

@@ -517,24 +554,14 @@ def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
         np.ndarray
             Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
         """
-        data, log_det_jac = self.adapter(data, strict=False, stage="inference", log_det_jac=True, **kwargs)
-        log_det_jac = log_det_jac.get("inference_variables", 0.0)
-
-        # Optionally standardize conditions
-        for key in ContinuousApproximator.CONDITION_KEYS:
-            if key in data and key in self.standardize:
-                data[key] = self.standardize_layers[key](data[key])
+        # Adapt, optionally standardize and convert to tensor. Keep track of log_det_jac.
+        data, log_det_jac = self._prepare_data(data, log_det_jac=True, **kwargs)
 
-        # Optionally standardize inference variables
-        if "inference_variables" in self.standardize:
-            data["inference_variables"], log_det_std = self.standardize_layers["inference_variables"](
-                data["inference_variables"], log_det_jac=True
-            )
-            log_det_jac += keras.ops.convert_to_numpy(log_det_std)
-
-        data = keras.tree.map_structure(keras.ops.convert_to_tensor, data)
+        # Pass data to networks and convert back to numpy array.
         log_prob = self._log_prob(**data, **kwargs)
         log_prob = keras.ops.convert_to_numpy(log_prob)
+
+        # Change of variables formula.
         log_prob = log_prob + log_det_jac
 
         return log_prob
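The refactor above consolidates the adapt, standardize, and convert-to-tensor steps that `sample` and `log_prob` previously duplicated into the single `_prepare_data` helper, which also threads the log-det-Jacobian of the adapter and standardization transforms into `log_prob`. A minimal numpy sketch of that change-of-variables bookkeeping, with a standard normal as a stand-in for the inference network (toy code under those assumptions, not the BayesFlow API):

```python
# Toy illustration: standardize x, evaluate a density in standardized space,
# then add the log-det-Jacobian of the standardization to recover log p(x).
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))   # raw inference variables
mean, std = x.mean(axis=0), x.std(axis=0)         # stand-in for the moving moments

z = (x - mean) / std                              # forward standardization
log_det_jac = -np.log(std).sum()                  # log |dz/dx|, same for every sample

log_prob_z = norm.logpdf(z).sum(axis=-1)          # "network" density in z-space
log_prob_x = log_prob_z + log_det_jac             # change of variables formula

# Sanity check against evaluating the matching density on x directly:
assert np.allclose(log_prob_x, norm.logpdf(x, loc=mean, scale=std).sum(axis=-1))
```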

bayesflow/approximators/point_approximator.py
Lines changed: 25 additions & 13 deletions

@@ -55,8 +55,10 @@ def estimate(
         Each estimator output (i.e., dictionary value that is not itself a dictionary) is an array
         of shape (num_datasets, point_estimate_size, variable_block_size).
         """
-
-        conditions = self._prepare_conditions(conditions, **kwargs)
+        # Adapt, optionally standardize and convert conditions to tensor.
+        conditions = self._prepare_data(conditions, **kwargs)
+        # Remove any superfluous keys, just retain actual conditions.  # TODO: is this necessary?
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
 
         estimates = self._estimate(**conditions, **kwargs)
         estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)

@@ -110,9 +112,19 @@ def sample(
         Each output (i.e., dictionary value that is not itself a dictionary) is an array
         of shape (num_datasets, num_samples, variable_block_size).
         """
-        conditions = self._prepare_conditions(conditions, **kwargs)
+        # Adapt, optionally standardize and convert conditions to tensor.
+        conditions = self._prepare_data(conditions, **kwargs)
+        # Remove any superfluous keys, just retain actual conditions.  # TODO: is this necessary?
+        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
 
+        # Sample and undo optional standardization
         samples = self._sample(num_samples, **conditions, **kwargs)
+
+        if "inference_variables" in self.standardize:
+            for score_key in samples.keys():
+                samples[score_key] = self.standardize_layers["inference_variables"](samples[score_key], forward=False)
+
+        samples = {"inference_variables": samples}
         samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
 
         if split:

@@ -152,20 +164,20 @@ def log_prob(
 
         Log-probabilities have shape (num_datasets,).
         """
-        return super().log_prob(data=data, **kwargs)
+        # Adapt, optionally standardize and convert to tensor. Keep track of log_det_jac.
+        data, log_det_jac = self._prepare_data(data, log_det_jac=True, **kwargs)
 
-    def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
-        """Adapts, optionally standardizes, and converts the conditions to tensors."""
+        # Pass data to networks and convert back to numpy array.
+        log_prob = self._log_prob(**data, **kwargs)
+        log_prob = keras.tree.map_structure(keras.ops.convert_to_numpy, log_prob)
 
-        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
-        conditions = {k: v for k, v in conditions.items() if k in ContinuousApproximator.CONDITION_KEYS}
+        # Change of variables formula, respecting that log_prob may be a dictionary.
+        if log_det_jac is not None:
+            log_prob = keras.tree.map_structure(lambda x: x + log_det_jac, log_prob)
 
-        # Optionally standardize conditions
-        for key, value in conditions.items():
-            if key in self.standardize:
-                conditions[key] = self.standardize_layers[key](value)
+        log_prob = PointApproximator._squeeze_parametric_score_major_dict(log_prob)
 
-        return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
+        return log_prob
 
     def _apply_inverse_adapter_to_estimates(
         self, estimates: Mapping[str, Mapping[str, Tensor]], **kwargs
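Since a `PointApproximator` can hold several parametric scores, its `log_prob` may return a dictionary of per-score log-densities, so the shared Jacobian correction is broadcast over the whole structure with `keras.tree.map_structure`. A toy sketch of that pattern (the score keys below are hypothetical):

```python
# Adding one scalar log-det-Jacobian term to every entry of a (possibly
# nested) dictionary of log-probabilities, as in PointApproximator.log_prob.
import numpy as np
import keras

log_prob = {"normal_score": np.zeros(8), "student_t_score": np.ones(8)}  # hypothetical keys
log_det_jac = -1.5                                                       # shared correction

log_prob = keras.tree.map_structure(lambda lp: lp + log_det_jac, log_prob)
print(log_prob["normal_score"][0], log_prob["student_t_score"][0])  # -1.5 -0.5
```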

bayesflow/datasets/disk_dataset.py
Lines changed: 8 additions & 3 deletions

@@ -37,6 +37,7 @@ def __init__(
         adapter: Adapter | None,
         stage: str = "training",
         augmentations: Mapping[str, Callable] | Callable = None,
+        shuffle: bool = True,
         **kwargs,
     ):
         """

@@ -67,6 +68,8 @@ def __init__(
 
             Note - augmentations are applied before the adapter is called and are generally
             transforms that you only want to apply during training.
+        shuffle : bool, optional
+            Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True.
         **kwargs
             Additional keyword arguments passed to the base `PyDataset`.
         """

@@ -79,8 +82,9 @@ def __init__(
         self.stage = stage
 
         self.augmentations = augmentations
-
-        self.shuffle()
+        self._shuffle = shuffle
+        if self._shuffle:
+            self.shuffle()
 
     def __getitem__(self, item) -> dict[str, np.ndarray]:
         if not 0 <= item < self.num_batches:

@@ -108,7 +112,8 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
         return batch
 
     def on_epoch_end(self):
-        self.shuffle()
+        if self._shuffle:
+            self.shuffle()
 
     @property
     def num_batches(self):

bayesflow/datasets/offline_dataset.py
Lines changed: 8 additions & 3 deletions

@@ -24,6 +24,7 @@ def __init__(
         *,
         stage: str = "training",
         augmentations: Mapping[str, Callable] | Callable = None,
+        shuffle: bool = True,
         **kwargs,
     ):
         """

@@ -51,6 +52,8 @@ def __init__(
 
             Note - augmentations are applied before the adapter is called and are generally
             transforms that you only want to apply during training.
+        shuffle : bool, optional
+            Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True.
         **kwargs
             Additional keyword arguments passed to the base `PyDataset`.
         """

@@ -69,8 +72,9 @@ def __init__(
         self.indices = np.arange(self.num_samples, dtype="int64")
 
         self.augmentations = augmentations
-
-        self.shuffle()
+        self._shuffle = shuffle
+        if self._shuffle:
+            self.shuffle()
 
     def __getitem__(self, item: int) -> dict[str, np.ndarray]:
         """

@@ -122,7 +126,8 @@ def num_batches(self) -> int | None:
         return int(np.ceil(self.num_samples / self.batch_size))
 
     def on_epoch_end(self) -> None:
-        self.shuffle()
+        if self._shuffle:
+            self.shuffle()
 
     def shuffle(self) -> None:
         """Shuffle the dataset in-place."""

bayesflow/experimental/free_form_flow/free_form_flow.py
Lines changed: 2 additions & 2 deletions

@@ -218,10 +218,10 @@ def decode(z):
             return self.decode(z, conditions, training=stage == "training")
 
         # VJP computation
-        z, vjp_fn = vjp(encode, x)
+        z, vjp_fn = vjp(encode, x, return_output=True)
         v1 = vjp_fn(v)[0]
         # JVP computation
-        x_pred, v2 = jvp(decode, (z,), (v,))
+        x_pred, v2 = jvp(decode, (z,), (v,), return_output=True)
 
         # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
         surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
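For context, this `vjp`/`jvp` pair is the free-form-flow surrogate for the gradient of `log|det J|`: with a random probe `v`, the inner product of `stop_gradient(J_decode v)` and `v^T J_encode` gives a Hutchinson-style trace estimate. A self-contained sketch using `torch.func` rather than BayesFlow's own `vjp`/`jvp` wrappers (whose new `return_output=True` flag the diff adopts); the toy encoder and decoder are assumptions for illustration:

```python
# Standalone version of the surrogate computation with torch.func.
import torch
from torch.func import jvp, vjp

def encode(x):                 # toy invertible "encoder"
    return torch.tanh(x) * 2.0

def decode(z):                 # its exact inverse
    return torch.atanh(z / 2.0)

x = torch.randn(8, 4)
v = torch.randn_like(x)        # random probe vector

z, vjp_fn = vjp(encode, x)     # VJP computation: v1 = v^T J_encode
v1 = vjp_fn(v)[0]
x_pred, v2 = jvp(decode, (z,), (v,))  # JVP computation: v2 = J_decode v

surrogate = (v2.detach() * v1).sum(-1)  # mirrors the ops.stop_gradient line above
```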

tests/test_approximators/conftest.py
Lines changed: 13 additions & 2 deletions

@@ -65,7 +65,7 @@ def point_inference_network_with_multiple_parametric_scores():
 
 
 @pytest.fixture()
-def point_approximator(adapter, point_inference_network, summary_network):
+def point_approximator_with_single_parametric_score(adapter, point_inference_network, summary_network):
     from bayesflow import PointApproximator
 
     return PointApproximator(

@@ -89,7 +89,18 @@ def point_approximator_with_multiple_parametric_scores(
 
 
 @pytest.fixture(
-    params=["continuous_approximator", "point_approximator", "point_approximator_with_multiple_parametric_scores"],
+    params=["point_approximator_with_single_parametric_score", "point_approximator_with_multiple_parametric_scores"]
+)
+def point_approximator(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.fixture(
+    params=[
+        "continuous_approximator",
+        "point_approximator_with_single_parametric_score",
+        "point_approximator_with_multiple_parametric_scores",
+    ],
     scope="function",
 )
 def approximator(request):
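The renamed fixtures use a standard pytest indirection: a parametrized fixture resolves each parameter name to another fixture via `request.getfixturevalue`, so `point_approximator` now fans out over both the single-score and multiple-score variants. A minimal standalone illustration of the pattern (fixture names are hypothetical):

```python
# Each param is the *name* of another fixture; getfixturevalue resolves it,
# so every test using `dataset` runs once per underlying fixture.
import pytest

@pytest.fixture()
def small_dataset():
    return [1, 2, 3]

@pytest.fixture()
def large_dataset():
    return list(range(100))

@pytest.fixture(params=["small_dataset", "large_dataset"])
def dataset(request):
    return request.getfixturevalue(request.param)

def test_dataset_nonempty(dataset):
    assert len(dataset) > 0
```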
New file
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import numpy as np
+from bayesflow.scores import ParametricDistributionScore
+from tests.utils import check_combination_simulator_adapter
+
+
+def test_approximator_log_prob(point_approximator, simulator, batch_size, num_samples, adapter):
+    check_combination_simulator_adapter(simulator, adapter)
+
+    data = simulator.sample((batch_size,))
+
+    batch = adapter(data)
+    point_approximator.build_from_data(batch)
+
+    log_prob = point_approximator.log_prob(data=data)
+    parametric_scores = [
+        score
+        for score in point_approximator.inference_network.scores.values()
+        if isinstance(score, ParametricDistributionScore)
+    ]
+
+    if len(parametric_scores) > 1:
+        assert isinstance(log_prob, dict)
+        for score_key, score_log_prob in log_prob.items():
+            assert isinstance(score_log_prob, np.ndarray)
+            assert score_log_prob.shape == (batch_size,)
+
+    # If only one score is available, the outer nesting should be dropped.
+    else:
+        assert isinstance(log_prob, np.ndarray)
+        assert log_prob.shape == (batch_size,)

tests/test_approximators/test_point_approximators/test_sample.py
Lines changed: 0 additions & 2 deletions

@@ -20,8 +20,6 @@ def test_approximator_sample(point_approximator, simulator, batch_size, num_samp
 
     assert isinstance(samples, dict)
 
-    print(keras.tree.map_structure(keras.ops.shape, samples))
-
     # Expect doubly nested sample dictionary if more than one samplable score is available.
     scores_for_sampling = [
         score

tests/test_networks/test_inference_networks.py
Lines changed: 9 additions & 0 deletions

@@ -150,3 +150,12 @@ def test_save_and_load(tmp_path, inference_network, random_samples, random_condi
     loaded = keras.saving.load_model(tmp_path / "model.keras")
 
     assert_layers_equal(inference_network, loaded)
+
+
+def test_compute_metrics(inference_network, random_samples, random_conditions):
+    xz_shape = keras.ops.shape(random_samples)
+    conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None
+    inference_network.build(xz_shape, conditions_shape)
+
+    metrics = inference_network.compute_metrics(random_samples, conditions=random_conditions)
+    assert "loss" in metrics
