Commit c5fd4b3: Merge branch 'dev' into numpy-transforms

2 parents: 94b8c99 + 74b9673

53 files changed: +4389 additions, -47 deletions (large commit; only a portion of the diff is shown below)

README.md: 2 additions & 1 deletion

@@ -100,7 +100,8 @@ Check out some of our walk-through notebooks below. We are actively working on p
 5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
 6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
 7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)
-8. More coming soon...
+8. [Rapid iteration with point estimation and expert statistics for Lotka-Volterra dynamics](examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb)
+9. More coming soon...

 ## Documentation \& Help


bayesflow/__init__.py: 2 additions & 1 deletion

@@ -11,8 +11,9 @@
     workflows,
     utils,
 )
+
 from .adapters import Adapter
-from .approximators import ContinuousApproximator
+from .approximators import ContinuousApproximator, PointApproximator
 from .datasets import OfflineDataset, OnlineDataset, DiskDataset
 from .simulators import make_simulator
 from .workflows import BasicWorkflow
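
With this change, the new class becomes importable at the package top level; a minimal sketch of the resulting import surface (assuming an environment with this branch installed):

import bayesflow as bf
from bayesflow import PointApproximator

# Both spellings resolve to the same class after this commit.
assert bf.PointApproximator is PointApproximator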

bayesflow/adapters/transforms/standardize.py: 34 additions & 2 deletions

@@ -11,8 +11,8 @@
 @serializable(package="bayesflow.adapters")
 class Standardize(ElementwiseTransform):
     """
-    Transform that when applied standardizes data using typical z-score standardization i.e. for some unstandardized
-    data x the standardized version z would be
+    Transform that when applied standardizes data using typical z-score standardization
+    i.e. for some unstandardized data x the standardized version z would be

     >>> z = (x - mean(x)) / std(x)

@@ -27,6 +27,38 @@ class Standardize(ElementwiseTransform):
         standardization happens individually for each dimension
     momentum : float in (0,1)
         The momentum during training
+
+    Examples
+    --------
+    1) Standardize all variables using their individually estimated means and stds.
+
+    >>> adapter = (
+            bf.adapters.Adapter()
+            .standardize()
+        )
+
+
+    2) Standardize all variables with the same known mean and std.
+
+    >>> adapter = (
+            bf.adapters.Adapter()
+            .standardize(mean=5, sd=10)
+        )
+
+
+    3) Mix of fixed and estimated means/stds. Suppose we have priors for "beta" and "sigma"
+       whose means and stds we know, while for all other variables they are unknown.
+       Then standardize should be applied in several stages, specifying which variables to include or exclude.
+
+    >>> adapter = (
+            bf.adapters.Adapter()
+            # mean fixed, std estimated
+            .standardize(include="beta", mean=1)
+            # both mean and std fixed
+            .standardize(include="sigma", mean=0.6, sd=3)
+            # both means and stds estimated for all other variables
+            .standardize(exclude=["beta", "sigma"])
+        )
     """

     def __init__(
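
As a quick numeric check of the z-score formula in the docstring above (values chosen to match example 2; not part of the diff):

import numpy as np

x = np.array([15.0, 5.0])
mean, sd = 5.0, 10.0
# With fixed mean and sd, standardization uses them instead of estimates:
z = (x - mean) / sd  # -> array([1.0, 0.0])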

bayesflow/approximators/__init__.py: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 from .approximator import Approximator
 from .continuous_approximator import ContinuousApproximator
+from .point_approximator import PointApproximator
 from .model_comparison_approximator import ModelComparisonApproximator

 from ..utils._docs import _add_imports_to_all

bayesflow/approximators/continuous_approximator.py: 41 additions & 1 deletion

@@ -11,7 +11,7 @@
 from bayesflow.adapters import Adapter
 from bayesflow.networks import InferenceNetwork, SummaryNetwork
 from bayesflow.types import Tensor
-from bayesflow.utils import filter_kwargs, logging, split_arrays
+from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict
 from .approximator import Approximator

@@ -120,6 +120,8 @@ def compute_metrics(
         else:
             inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

+        # Force a conversion to Tensor
+        inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
         inference_metrics = self.inference_network.compute_metrics(
             inference_variables, conditions=inference_conditions, stage=stage
         )

@@ -205,6 +207,44 @@ def get_config(self):

         return base_config | config

+    def estimate(
+        self,
+        conditions: dict[str, np.ndarray],
+        split: bool = False,
+        estimators: dict[str, callable] = None,
+        num_samples: int = 1000,
+        **kwargs,
+    ) -> dict[str, dict[str, np.ndarray]]:
+        estimators = estimators or {}
+        estimators = (
+            dict(
+                mean=lambda x, axis: dict(value=np.mean(x, keepdims=True, axis=axis)),
+                median=lambda x, axis: dict(value=np.median(x, keepdims=True, axis=axis)),
+                quantiles=lambda x, axis: dict(value=np.moveaxis(np.quantile(x, q=[0.1, 0.5, 0.9], axis=axis), 0, 1)),
+            )
+            | estimators
+        )
+
+        samples = self.sample(num_samples=num_samples, conditions=conditions, split=split, **kwargs)
+
+        estimates = {
+            variable_name: {
+                estimator_name: func(samples[variable_name], axis=1) for estimator_name, func in estimators.items()
+            }
+            for variable_name in samples.keys()
+        }
+
+        # remove unnecessary nesting
+        estimates = {
+            variable_name: {
+                outer_key: squeeze_inner_estimates_dict(estimates[variable_name][outer_key])
+                for outer_key in estimates[variable_name].keys()
+            }
+            for variable_name in estimates.keys()
+        }
+
+        return estimates
+
     def sample(
         self,
         *,
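
For orientation, here is a short, hedged sketch of how the new sampling-based `estimate` API can be called. `approximator`, `conditions`, and the variable name `"theta"` are hypothetical; the custom `"std"` entry follows the `func(x, axis) -> dict(value=...)` convention of the built-in `mean`, `median`, and `quantiles` estimators above:

import numpy as np

# estimate() draws num_samples posterior samples per dataset and reduces them
# along the sample axis (axis=1) with each estimator; user-supplied estimators
# are merged with (and can override) the defaults.
estimates = approximator.estimate(
    conditions=conditions,  # hypothetical dict of observed conditions
    num_samples=500,
    estimators={
        # Custom estimator: posterior standard deviation per parameter.
        "std": lambda x, axis: dict(value=np.std(x, keepdims=True, axis=axis)),
    },
)

# Results are nested as variable name -> estimator name:
posterior_mean = estimates["theta"]["mean"]
posterior_std = estimates["theta"]["std"]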
bayesflow/approximators/point_approximator.py (new file): 154 additions & 0 deletions

@@ -0,0 +1,154 @@
+import keras
+import numpy as np
+from keras.saving import (
+    register_keras_serializable as serializable,
+)
+
+from bayesflow.types import Tensor
+from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
+from .continuous_approximator import ContinuousApproximator
+
+
+@serializable(package="bayesflow.approximators")
+class PointApproximator(ContinuousApproximator):
+    """
+    A workflow for fast amortized point estimation of a conditional distribution.
+
+    The distribution is approximated by point estimators, parameterized by a feed-forward `PointInferenceNetwork`.
+    Conditions can be compressed by an optional `SummaryNetwork` or used directly as input to the inference network.
+    """
+
+    def estimate(
+        self,
+        conditions: dict[str, np.ndarray],
+        split: bool = False,
+        **kwargs,
+    ) -> dict[str, dict[str, np.ndarray]]:
+        conditions = self._prepare_conditions(conditions, **kwargs)
+        estimates = self._estimate(**conditions, **kwargs)
+        estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
+        # Optionally split the arrays along the last axis.
+        if split:
+            estimates = split_arrays(estimates, axis=-1)
+        # Reorder the nested dictionary so that original variable names are at the top.
+        estimates = self._reorder_estimates(estimates)
+        # Remove unnecessary nesting.
+        estimates = self._squeeze_estimates(estimates)
+
+        return estimates
+
+    def sample(
+        self,
+        *,
+        num_samples: int,
+        conditions: dict[str, np.ndarray],
+        split: bool = False,
+        **kwargs,
+    ) -> dict[str, np.ndarray]:
+        conditions = self._prepare_conditions(conditions, **kwargs)
+        samples = self._sample(num_samples, **conditions, **kwargs)
+        samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
+        # Optionally split the arrays along the last axis.
+        if split:
+            samples = split_arrays(samples, axis=-1)
+        # Squeeze samples if there's only one key-value pair.
+        samples = self._squeeze_samples(samples)
+
+        return samples
+
+    def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
+        """Adapts and converts the conditions to tensors."""
+        conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
+        return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
+
+    def _apply_inverse_adapter_to_estimates(
+        self, estimates: dict[str, dict[str, Tensor]], **kwargs
+    ) -> dict[str, dict[str, dict[str, np.ndarray]]]:
+        """Applies the inverse adapter on each inner element of the _estimate output dictionary."""
+        estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
+        processed = {}
+        for score_key, score_val in estimates.items():
+            processed[score_key] = {}
+            for head_key, estimate in score_val.items():
+                adapted = self.adapter(
+                    {"inference_variables": estimate},
+                    inverse=True,
+                    strict=False,
+                    **kwargs,
+                )
+                processed[score_key][head_key] = adapted
+        return processed
+
+    def _apply_inverse_adapter_to_samples(
+        self, samples: dict[str, Tensor], **kwargs
+    ) -> dict[str, dict[str, np.ndarray]]:
+        """Applies the inverse adapter to a dictionary of samples."""
+        samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
+        processed = {}
+        for score_key, score_samples in samples.items():
+            processed[score_key] = self.adapter(
+                {"inference_variables": score_samples},
+                inverse=True,
+                strict=False,
+                **kwargs,
+            )
+        return processed
+
+    def _reorder_estimates(
+        self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
+    ) -> dict[str, dict[str, dict[str, np.ndarray]]]:
+        """Reorders the nested dictionary so that the inference variable names become the top-level keys."""
+        # Grab the variable names from one sample inner dictionary.
+        sample_inner = next(iter(next(iter(estimates.values())).values()))
+        variable_names = sample_inner.keys()
+        reordered = {}
+        for variable in variable_names:
+            reordered[variable] = {}
+            for score_key, inner_dict in estimates.items():
+                reordered[variable][score_key] = {inner_key: value[variable] for inner_key, value in inner_dict.items()}
+        return reordered
+
+    def _squeeze_estimates(
+        self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
+    ) -> dict[str, dict[str, np.ndarray]]:
+        """Squeezes each inner estimate dictionary to remove unnecessary nesting."""
+        squeezed = {}
+        for variable, variable_estimates in estimates.items():
+            squeezed[variable] = {
+                score_key: squeeze_inner_estimates_dict(inner_estimate)
+                for score_key, inner_estimate in variable_estimates.items()
+            }
+        return squeezed
+
+    def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
+        """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
+        if len(samples) == 1:
+            return next(iter(samples.values()))  # Extract and return the only item's value
+        return samples
+
+    def _estimate(
+        self,
+        inference_conditions: Tensor = None,
+        summary_variables: Tensor = None,
+        **kwargs,
+    ) -> dict[str, dict[str, Tensor]]:
+        if self.summary_network is None:
+            if summary_variables is not None:
+                raise ValueError("Cannot use summary variables without a summary network.")
+        else:
+            if summary_variables is None:
+                raise ValueError("Summary variables are required when a summary network is present.")
+
+            summary_outputs = self.summary_network(
+                summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
+            )
+
+            if inference_conditions is None:
+                inference_conditions = summary_outputs
+            else:
+                inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)
+
+        return self.inference_network(
+            conditions=inference_conditions,
+            **filter_kwargs(kwargs, self.inference_network.call),
+        )
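
A short usage sketch under stated assumptions: `point_approximator` is a trained `PointApproximator` (its inference network is the `PointInferenceNetwork` referenced in the class docstring, not shown in this hunk), `conditions` is a dictionary of observed data, and `"theta"`/`"mean"` are illustrative variable and score names:

# Point estimates come from a single forward pass through the feed-forward
# inference network; no posterior sampling is involved.
estimates = point_approximator.estimate(conditions=conditions)

# After _reorder_estimates and _squeeze_estimates, results are keyed by
# variable name first, then by scoring rule / head, e.g.:
theta_mean = estimates["theta"]["mean"]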

bayesflow/diagnostics/__init__.py: 2 additions & 0 deletions

@@ -2,6 +2,7 @@

 from .plots import (
     calibration_ecdf,
+    calibration_ecdf_from_quantiles,
     calibration_histogram,
     loss,
     mc_calibration,

@@ -10,6 +11,7 @@
     pairs_posterior,
     pairs_samples,
     recovery,
+    recovery_from_estimates,
     z_score_contraction,
 )

bayesflow/diagnostics/plots/__init__.py: 2 additions & 0 deletions

@@ -1,4 +1,5 @@
 from .calibration_ecdf import calibration_ecdf
+from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
 from .calibration_histogram import calibration_histogram
 from .loss import loss
 from .mc_calibration import mc_calibration

@@ -7,4 +8,5 @@
 from .pairs_posterior import pairs_posterior
 from .pairs_samples import pairs_samples
 from .recovery import recovery
+from .recovery_from_estimates import recovery_from_estimates
 from .z_score_contraction import z_score_contraction
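
The two new plotting utilities are exposed from both `bayesflow.diagnostics` and `bayesflow.diagnostics.plots`; their call signatures are not shown in this diff, so only the import surface is illustrated here:

from bayesflow.diagnostics import calibration_ecdf_from_quantiles, recovery_from_estimates
from bayesflow.diagnostics.plots import recovery_from_estimates as recovery_plot

# Both entry points expose the same function object.
assert recovery_from_estimates is recovery_plot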

bayesflow/diagnostics/plots/calibration_ecdf.py: 3 additions & 3 deletions

@@ -163,12 +163,12 @@ def calibration_ecdf(
         plot_data["axes"].flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")

     # Compute uniform ECDF and bands
-    alpha, z, L, H = simultaneous_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))
+    alpha, z, L, U = simultaneous_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))

     # Difference, if specified
     if difference:
         L -= z
-        H -= z
+        U -= z
         ylab = "ECDF Difference"
     else:
         ylab = "ECDF"

@@ -182,7 +182,7 @@ def calibration_ecdf(
     titles = ["Stacked ECDFs"]

     for ax, title in zip(plot_data["axes"].flat, titles):
-        ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
+        ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
         ax.legend(fontsize=legend_fontsize)
         ax.set_title(title, fontsize=title_fontsize)
