Merged

Changes from all commits (70 commits)
093785d
Draft implementation of quantile estimation
han-ol Dec 12, 2024
aa5dd93
Add notebook showing basic quantile estimation
han-ol Dec 12, 2024
60050ec
Merge branch 'dev' into point-estimation
han-ol Jan 6, 2025
0cd3110
automatic head building for multiple scoring rules
han-ol Jan 8, 2025
98437ae
Update notebook to show customizing the PointInferenceNetwork with mu…
han-ol Jan 8, 2025
ce07855
Add activation function for quantile estimation
han-ol Jan 10, 2025
1a85fc8
Merge branch 'dev' into point-estimation
han-ol Jan 28, 2025
ecef8ed
Refactor scores
han-ol Jan 28, 2025
ab821bf
Refactor links [no ci]
han-ol Jan 28, 2025
13ce858
Refactor automatic head building, add serialization, tests, unconditi…
han-ol Feb 6, 2025
daf69f6
Fix: Avoid shared attributes due to mutable default keyword arguments
han-ol Feb 7, 2025
3d639ed
Rely on keras to build PointInferenceNetwork and warn if impossible
han-ol Feb 10, 2025
21fb5f2
Estimate method for Point- and ContinuousApproximator
han-ol Feb 11, 2025
61267b9
Merge branch 'dev' into point-estimation
han-ol Feb 11, 2025
63d448b
Squeeze head shape for normed difference scores
han-ol Feb 12, 2025
b69df85
Rename target and reference to estimates and targets; add optional we…
han-ol Feb 12, 2025
ba5cc67
Remove regressor draft
han-ol Feb 12, 2025
3d678cf
fix: target should be targets
han-ol Feb 12, 2025
5cb5cc9
Quantile loss aggregation
han-ol Feb 12, 2025
dac7e9b
Merge branch 'dev' into point-estimation
han-ol Feb 13, 2025
50e01b1
Change ruff version so lint test can pass
han-ol Feb 13, 2025
95220cd
Merge branch 'dev' into point-estimation
han-ol Feb 14, 2025
72bb994
Merge branch 'dev' into point-estimation
han-ol Feb 18, 2025
1932724
Fix #322: docstring and refactor for calibration_ecdf
han-ol Feb 18, 2025
38fe5bf
Recovery and calibration diagnostics applicable on point estimates
han-ol Feb 18, 2025
149d1ad
Pointwise confidence bands, change confidence and add label
han-ol Feb 28, 2025
887be8f
MultivariateNormal sample fix, training still unstable, added warning
han-ol Feb 28, 2025
849c67c
Support point inference for BasicWorkflow
han-ol Mar 3, 2025
d8dc9a8
fix kwargs in estimate
han-ol Mar 3, 2025
ff844fe
Lotka-Volterra tutorial
han-ol Mar 3, 2025
e67e91b
Point estimation in regression starter notebook (without narration)
han-ol Mar 3, 2025
16a4d0a
Merge branch 'dev' into point-estimation
han-ol Mar 3, 2025
1e92c8d
Rename example notebook to use underscore
han-ol Mar 3, 2025
c051fcc
Link Lotka-Volterra notebook
han-ol Mar 3, 2025
a9f10c3
fix fixture: typical_point_inference_network_subnet
han-ol Mar 3, 2025
69da9c8
Diagonal normal posterior predictive check for quantile estimates
han-ol Mar 4, 2025
b29d3fa
Update Lotka-Volterra links and names
han-ol Mar 4, 2025
c9d6f3c
Merge branch 'dev' into point-estimation
han-ol Mar 4, 2025
2ac5f3c
fix typos
han-ol Mar 5, 2025
95cd950
Populate __all__ for scores and links
han-ol Mar 5, 2025
d76ca9f
Docs for scores module [no-ci]
han-ol Mar 5, 2025
18f3a6e
[no-ci] Some docstrings
han-ol Mar 11, 2025
7955cbd
[no-ci] fix: filter kwargs of InferenceNetwork
han-ol Mar 11, 2025
39f17ee
Spelling, inline _forward and comments
han-ol Mar 11, 2025
d5cb12e
[no-ci] spelling
han-ol Mar 11, 2025
aa20868
Merge branch 'dev' into point-estimation
han-ol Mar 13, 2025
c2ca810
Refactor to remove set_head_shapes_from_target_shape
han-ol Mar 13, 2025
26f6499
Docs for building PointInferenceNetwork from scores
han-ol Mar 13, 2025
3f4e60d
[no-ci] Short doc string for OrderedQuantiles
han-ol Mar 13, 2025
e2e3b72
[no-ci] Key assignment in compute_metrics
han-ol Mar 13, 2025
23c1dd8
[no-ci] Don't add identity in positive semi definite link
han-ol Mar 13, 2025
a80011c
Move part of nested dictionary operations to dict_utils
han-ol Mar 13, 2025
8aebe6a
Docstring for links module
han-ol Mar 13, 2025
0035e3d
Refactor mean over scores
han-ol Mar 13, 2025
06ea8df
Remove draft notebook
han-ol Mar 13, 2025
3d56323
Minor change to comment
han-ol Mar 13, 2025
5a2ea53
Two proposals to include the names of head layers in serialization
han-ol Mar 13, 2025
40ccf08
More narration for Lotka-Volterra notebook
han-ol Mar 14, 2025
011f0ce
Fix typo
stefanradev93 Mar 14, 2025
4c0794f
Refactor scores
stefanradev93 Mar 14, 2025
8de49ac
Make abstract scores importable
han-ol Mar 14, 2025
94720f2
Refactor nested processing of estimates
han-ol Mar 14, 2025
bff8d20
Remove comments wrt serialization proposal and fix score name change
han-ol Mar 14, 2025
50167b3
Sample from parametric scoring rules; more refactoring in PointApprox…
han-ol Mar 14, 2025
23790c6
More detailed test for point inference network serialization
han-ol Mar 14, 2025
44105b7
Remove config comparison for heads
han-ol Mar 14, 2025
bd812ef
More narration for Lotka-Volterra notebook
han-ol Mar 14, 2025
c99fd01
Force conversion to Tensor of inference variables
han-ol Mar 14, 2025
a55612e
Fix quantile level serialization and add save/load to notebook
han-ol Mar 15, 2025
7114293
Formatting fix in notebook
han-ol Mar 15, 2025
3 changes: 2 additions & 1 deletion README.md
@@ -100,7 +100,8 @@ Check out some of our walk-through notebooks below. We are actively working on p
5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)
8. More coming soon...
8. [Rapid iteration with point estimation and expert statistics for Lotka-Volterra dynamics](examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb)
9. More coming soon...

## Documentation \& Help

3 changes: 2 additions & 1 deletion bayesflow/__init__.py
@@ -11,8 +11,9 @@
workflows,
utils,
)

from .adapters import Adapter
from .approximators import ContinuousApproximator
from .approximators import ContinuousApproximator, PointApproximator
from .datasets import OfflineDataset, OnlineDataset, DiskDataset
from .simulators import make_simulator
from .workflows import BasicWorkflow
1 change: 1 addition & 0 deletions bayesflow/approximators/__init__.py
@@ -1,5 +1,6 @@
from .approximator import Approximator
from .continuous_approximator import ContinuousApproximator
from .point_approximator import PointApproximator
from .model_comparison_approximator import ModelComparisonApproximator

from ..utils._docs import _add_imports_to_all
42 changes: 41 additions & 1 deletion bayesflow/approximators/continuous_approximator.py
@@ -11,7 +11,7 @@
from bayesflow.adapters import Adapter
from bayesflow.networks import InferenceNetwork, SummaryNetwork
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs, logging, split_arrays
from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict
from .approximator import Approximator


@@ -120,6 +120,8 @@ def compute_metrics(
else:
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=-1)

# Force a conversion to Tensor
inference_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, inference_variables)
inference_metrics = self.inference_network.compute_metrics(
inference_variables, conditions=inference_conditions, stage=stage
)
@@ -205,6 +207,44 @@ def get_config(self):

return base_config | config

def estimate(
self,
conditions: dict[str, np.ndarray],
split: bool = False,
estimators: dict[str, callable] = None,
num_samples: int = 1000,
**kwargs,
) -> dict[str, dict[str, np.ndarray]]:
estimators = estimators or {}
Contributor:

This is very strongly nested, and I currently don't understand the reason for this. Keep in mind that objects you return to the user should be somewhat simple, otherwise users will get confused.

Collaborator Author (@han-ol):

The ContinuousApproximator.estimate method is (just) a convenience method to ease interoperability with the PointApproximator.estimate method.

Here, the nesting is irreducible if we want the output to be the same as with the PointApproximator.estimate method.

I commented on nesting complexity over there (PointApproximator.estimate), since that is where the original estimation takes place, and I believe that is where the real question of whether to refactor with respect to nesting lies.

estimators = (

dict(
mean=lambda x, axis: dict(value=np.mean(x, keepdims=True, axis=axis)),
median=lambda x, axis: dict(value=np.median(x, keepdims=True, axis=axis)),
quantiles=lambda x, axis: dict(value=np.moveaxis(np.quantile(x, q=[0.1, 0.5, 0.9], axis=axis), 0, 1)),
)
| estimators
)

samples = self.sample(num_samples=num_samples, conditions=conditions, split=split, **kwargs)

estimates = {
variable_name: {
estimator_name: func(samples[variable_name], axis=1) for estimator_name, func in estimators.items()
}
for variable_name in samples.keys()
}

# remove unnecessary nesting
estimates = {
variable_name: {
outer_key: squeeze_inner_estimates_dict(estimates[variable_name][outer_key])
for outer_key in estimates[variable_name].keys()
}
for variable_name in estimates.keys()
}

return estimates

def sample(
self,
*,
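For orientation, the default estimators in the new estimate method reduce posterior draws along the samples axis. The following self-contained sketch mirrors the defaults above, but the array shapes and values are invented for illustration and are not part of the diff:

import numpy as np

# Stand-in for ContinuousApproximator.sample output: 3 datasets, 1000 draws, 2 variables.
samples = np.random.default_rng(1).normal(size=(3, 1000, 2))

# The same default estimators as in the diff above.
estimators = dict(
    mean=lambda x, axis: dict(value=np.mean(x, keepdims=True, axis=axis)),
    median=lambda x, axis: dict(value=np.median(x, keepdims=True, axis=axis)),
    quantiles=lambda x, axis: dict(value=np.moveaxis(np.quantile(x, q=[0.1, 0.5, 0.9], axis=axis), 0, 1)),
)

# Reduce along the samples axis (axis=1), as estimate does internally.
estimates = {name: func(samples, axis=1) for name, func in estimators.items()}
for name, inner in estimates.items():
    print(name, inner["value"].shape)
# mean (3, 1, 2), median (3, 1, 2), quantiles (3, 3, 2)

The final squeeze step presumably unwraps each inner {"value": ...} dictionary to the bare array when "value" is its only key, which is why custom estimators passed by the user should follow the same dict(value=...) convention.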
154 changes: 154 additions & 0 deletions bayesflow/approximators/point_approximator.py
@@ -0,0 +1,154 @@
import keras
import numpy as np
from keras.saving import (
register_keras_serializable as serializable,
)

from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
from .continuous_approximator import ContinuousApproximator


@serializable(package="bayesflow.approximators")
class PointApproximator(ContinuousApproximator):
"""
A workflow for fast amortized point estimation of a conditional distribution.

The distribution is approximated by point estimators, parameterized by a feed-forward `PointInferenceNetwork`.
Conditions can be compressed by an optional `SummaryNetwork` or used directly as input to the inference network.
"""

Collaborator Author (@han-ol, Mar 11, 2025):

My two cents about nesting complexity.

1. Inherent complexity of output

The output structure of PointApproximator.estimate should identify the variable names and the point estimate kinds. To be as close as possible to the ContinuousApproximator.sample output, I am quite happy with a variable-name-major nesting.

Replacing tensors with their shapes, this looks like the following:

{'alpha': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'beta': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'gamma': {'mean': (500, 1), 'quantiles': (500, 5, 1)},
 'delta': {'mean': (500, 1), 'quantiles': (500, 5, 1)}}

To see it in context, check the notebook:
https://github.com/han-ol/bayesflow/blob/point-estimation/examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb

2. Code readability / implementation complexity

Having the functionality in the first place requires a few computations that are nested. Maybe some nice utility functions can aid here, but before coming up with those I wanted to err on the side of not hiding how it works. If it is not readable, I should improve the comments. What do you think?

Contributor (@LarsKue, Mar 13, 2025):

I am fine with doubly nested outputs as you show them, assuming they are necessary and we cannot efficiently do something like estimate_mean and estimate_quantiles instead. However, internally, you have a triply nested dictionary here, and I do not yet see the reason for it.

For readability, a loop might be preferred over a dictionary comprehension.

Collaborator Author (@han-ol, Mar 13, 2025):

Triple nesting is necessary for scores that need multiple estimates, for example if you estimate mean and (co)variance.

def estimate(

self,
conditions: dict[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, dict[str, np.ndarray]]:
conditions = self._prepare_conditions(conditions, **kwargs)
estimates = self._estimate(**conditions, **kwargs)
estimates = self._apply_inverse_adapter_to_estimates(estimates, **kwargs)
# Optionally split the arrays along the last axis.
if split:
estimates = split_arrays(estimates, axis=-1)
# Reorder the nested dictionary so that original variable names are at the top.
estimates = self._reorder_estimates(estimates)
# Remove unnecessary nesting.
estimates = self._squeeze_estimates(estimates)

return estimates

def sample(
self,
*,
num_samples: int,
conditions: dict[str, np.ndarray],
split: bool = False,
**kwargs,
) -> dict[str, np.ndarray]:
conditions = self._prepare_conditions(conditions, **kwargs)
samples = self._sample(num_samples, **conditions, **kwargs)
samples = self._apply_inverse_adapter_to_samples(samples, **kwargs)
# Optionally split the arrays along the last axis.
if split:
samples = split_arrays(samples, axis=-1)
# Squeeze samples if there's only one key-value pair.
samples = self._squeeze_samples(samples)

return samples

def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
"""Adapts and converts the conditions to tensors."""
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)

def _apply_inverse_adapter_to_estimates(
self, estimates: dict[str, dict[str, Tensor]], **kwargs
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
"""Applies the inverse adapter on each inner element of the _estimate output dictionary."""
estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
processed = {}
for score_key, score_val in estimates.items():
processed[score_key] = {}
for head_key, estimate in score_val.items():
adapted = self.adapter(
{"inference_variables": estimate},
inverse=True,
strict=False,
**kwargs,
)
processed[score_key][head_key] = adapted
return processed

def _apply_inverse_adapter_to_samples(
self, samples: dict[str, Tensor], **kwargs
) -> dict[str, dict[str, np.ndarray]]:
"""Applies the inverse adapter to a dictionary of samples."""
samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
processed = {}
for score_key, variable_samples in samples.items():
processed[score_key] = self.adapter(
{"inference_variables": variable_samples},
inverse=True,
strict=False,
**kwargs,
)
return processed

def _reorder_estimates(
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
"""Reorders the nested dictionary so that the inference variable names become the top-level keys."""
# Grab the variable names from one sample inner dictionary.
sample_inner = next(iter(next(iter(estimates.values())).values()))
variable_names = sample_inner.keys()
reordered = {}
for variable in variable_names:
reordered[variable] = {}
for score_key, inner_dict in estimates.items():
reordered[variable][score_key] = {inner_key: value[variable] for inner_key, value in inner_dict.items()}
return reordered

def _squeeze_estimates(
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
) -> dict[str, dict[str, np.ndarray]]:
"""Squeezes each inner estimate dictionary to remove unnecessary nesting."""
squeezed = {}
for variable, variable_estimates in estimates.items():
squeezed[variable] = {
score_key: squeeze_inner_estimates_dict(inner_estimate)
for score_key, inner_estimate in variable_estimates.items()
}
return squeezed

def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray | dict[str, np.ndarray]:
"""Squeezes the samples dictionary to just the value if there is only one key-value pair."""
if len(samples) == 1:
return next(iter(samples.values())) # Extract and return the only item's value
return samples

def _estimate(
self,
inference_conditions: Tensor = None,
summary_variables: Tensor = None,
**kwargs,
) -> dict[str, dict[str, Tensor]]:
if self.summary_network is None:
if summary_variables is not None:
raise ValueError("Cannot use summary variables without a summary network.")
else:
if summary_variables is None:
raise ValueError("Summary variables are required when a summary network is present.")

summary_outputs = self.summary_network(
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
)

if inference_conditions is None:
inference_conditions = summary_outputs
else:
inference_conditions = keras.ops.concatenate([inference_conditions, summary_outputs], axis=1)

return self.inference_network(
conditions=inference_conditions,
**filter_kwargs(kwargs, self.inference_network.call),
)
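To make the triply nested internal layout from the review discussion concrete, here is a self-contained sketch of the reordering that _reorder_estimates performs; the score names, head names, and shapes are invented for illustration and do not come from the diff:

import numpy as np

# Internal layout after the inverse adapter: {score: {head: {variable: array}}}.
estimates = {
    "normal_score": {
        "mean": {"alpha": np.zeros((500, 1)), "beta": np.zeros((500, 1))},
        "covariance": {"alpha": np.zeros((500, 1, 1)), "beta": np.zeros((500, 1, 1))},
    },
    "quantile_score": {
        "value": {"alpha": np.zeros((500, 5, 1)), "beta": np.zeros((500, 5, 1))},
    },
}

# Reorder so the variable names become the top-level keys, as in _reorder_estimates:
# grab the variable names from one inner dictionary, then regroup per variable.
sample_inner = next(iter(next(iter(estimates.values())).values()))
reordered = {
    variable: {
        score_key: {head_key: value[variable] for head_key, value in inner_dict.items()}
        for score_key, inner_dict in estimates.items()
    }
    for variable in sample_inner
}

# Squeezing then reduces single-head entries like {"value": arr} to arr, so the
# user-facing result is doubly nested for simple scores and stays triply nested
# only where a score genuinely has several heads (e.g., mean and covariance).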
2 changes: 2 additions & 0 deletions bayesflow/diagnostics/__init__.py
@@ -2,6 +2,7 @@

from .plots import (
calibration_ecdf,
calibration_ecdf_from_quantiles,
calibration_histogram,
loss,
mc_calibration,
@@ -10,6 +11,7 @@
pairs_posterior,
pairs_samples,
recovery,
recovery_from_estimates,
z_score_contraction,
)

2 changes: 2 additions & 0 deletions bayesflow/diagnostics/plots/__init__.py
@@ -1,4 +1,5 @@
from .calibration_ecdf import calibration_ecdf
from .calibration_ecdf_from_quantiles import calibration_ecdf_from_quantiles
from .calibration_histogram import calibration_histogram
from .loss import loss
from .mc_calibration import mc_calibration
@@ -7,4 +8,5 @@
from .pairs_posterior import pairs_posterior
from .pairs_samples import pairs_samples
from .recovery import recovery
from .recovery_from_estimates import recovery_from_estimates
from .z_score_contraction import z_score_contraction
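The two new point-estimate diagnostics are importable from bayesflow.diagnostics as shown in this diff. As a hedged sketch only, a call might look like the following; the argument names are assumptions inferred from the "Rename target and reference to estimates and targets" commit, not verified signatures:

from bayesflow.diagnostics import recovery_from_estimates

# estimates: output of PointApproximator.estimate (hypothetical variable here)
# targets: dict of ground-truth parameter arrays from validation simulations
fig = recovery_from_estimates(estimates, targets)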
6 changes: 3 additions & 3 deletions bayesflow/diagnostics/plots/calibration_ecdf.py
@@ -163,12 +163,12 @@ def calibration_ecdf(
plot_data["axes"].flat[j].plot(xx, yy, color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")

# Compute uniform ECDF and bands
alpha, z, L, H = simultaneous_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))
alpha, z, L, U = simultaneous_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))

# Difference, if specified
if difference:
L -= z
H -= z
U -= z
ylab = "ECDF Difference"
else:
ylab = "ECDF"
@@ -182,7 +182,7 @@
titles = ["Stacked ECDFs"]

for ax, title in zip(plot_data["axes"].flat, titles):
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize)
ax.set_title(title, fontsize=title_fontsize)
