Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
51249c0
- Add `mmd_hypothesis_test.py` module to the `diagnostics/metrics` pa…
thegialeo Apr 1, 2025
e04ea09
draft implementation for mmd_hypothesis_test_from_summaries() -> TODO…
thegialeo Apr 1, 2025
f01c5db
draft implementation for mmd_hypothesis_test() -> TODO: approximator.…
thegialeo Apr 1, 2025
f843462
add draft module docstring
thegialeo Apr 1, 2025
927072f
run pre-commit
thegialeo Apr 1, 2025
0333c77
add paper reference to module docstring of diagnostics/metrics/mmd_hy…
thegialeo Apr 2, 2025
48c2f8c
add type casting between np.ndarray and bf.types.Tensor with keras.ops
thegialeo Apr 2, 2025
4e91154
change functions names to have compute_ prefix (see PR comments)
thegialeo Apr 2, 2025
d6aa068
move module docstring to function docstrings
thegialeo Apr 2, 2025
76f88b5
import compute_mmd_hypothesis_test and compute_mmd_hypothesis_test_fr…
thegialeo Apr 2, 2025
f2f863c
- update compute_mmd_hypothesis_test_from_summaries implementation
thegialeo Apr 2, 2025
d8cadf9
update implementation of compute_mmd_hypothesis_test + restrict appro…
thegialeo Apr 2, 2025
c8f54c0
add unit test for output shape of compute_mmd_hypothesis_test
thegialeo Apr 2, 2025
ef9dd00
handle case for when ContinuousApproximator.summary_network is None
thegialeo Apr 2, 2025
f5687e8
add unit test case for when ContinuousApproximator.summary_network = …
thegialeo Apr 2, 2025
5b502af
unit test for computed MMD values to be positive
thegialeo Apr 2, 2025
ec1e1d3
add test cases for same and different distributions mmd hypothesis te…
thegialeo Apr 2, 2025
8bacff9
add Raises for unequal shapes of observed and reference summaries on …
thegialeo Apr 2, 2025
454304b
add Raises to docstring
thegialeo Apr 2, 2025
ab0b895
add test cases for Raises
thegialeo Apr 2, 2025
6fa1925
add test cases for indirect Raises through compute_hypothesis_test
thegialeo Apr 2, 2025
3c6c33c
add Raises to compute_hypothesis_test for unmatching observed and ref…
thegialeo Apr 2, 2025
c2e52c4
rename unit test functions
thegialeo Apr 2, 2025
b29a365
remove transitive Raises unit test to avoid coupling in the testing
thegialeo Apr 3, 2025
f67da8a
adjust mock summary_network in unit tests to be a deterministic trans…
thegialeo Apr 3, 2025
dbdbe2e
allow approximator argument to also be of bayesflow.network.SummaryNe…
thegialeo Apr 3, 2025
946e405
add Raise to docstring if approximator is not ContinuousApproximator …
thegialeo Apr 3, 2025
dd725b7
add test cases for when approximator argument is bayesflow.networks.S…
thegialeo Apr 3, 2025
cd3e0f6
edit docstring
thegialeo Apr 3, 2025
d0667ac
correct bug: exception should not be raised for num_null_samples > nu…
thegialeo Apr 4, 2025
025c8bd
- raise exception for num_null_samples zero or negative
thegialeo Apr 4, 2025
6600e14
adjust unit tests
thegialeo Apr 6, 2025
5e3f5f1
- Add `mmd_hypothesis_test.py` module to the `diagnostics/metrics` pa…
thegialeo Apr 1, 2025
3ef4bb8
draft implementation for mmd_hypothesis_test_from_summaries() -> TODO…
thegialeo Apr 1, 2025
fe86176
draft implementation for mmd_hypothesis_test() -> TODO: approximator.…
thegialeo Apr 1, 2025
e88e7b0
add draft module docstring
thegialeo Apr 1, 2025
72dce22
run pre-commit
thegialeo Apr 1, 2025
2a0a0e4
add paper reference to module docstring of diagnostics/metrics/mmd_hy…
thegialeo Apr 2, 2025
6734016
add type casting between np.ndarray and bf.types.Tensor with keras.ops
thegialeo Apr 2, 2025
52f6ca9
change functions names to have compute_ prefix (see PR comments)
thegialeo Apr 2, 2025
34be3c3
move module docstring to function docstrings
thegialeo Apr 2, 2025
7429ac2
resolve merge conflict
thegialeo Apr 17, 2025
e051e5d
- update compute_mmd_hypothesis_test_from_summaries implementation
thegialeo Apr 2, 2025
2ace242
update implementation of compute_mmd_hypothesis_test + restrict appro…
thegialeo Apr 2, 2025
0c1e973
add unit test for output shape of compute_mmd_hypothesis_test
thegialeo Apr 2, 2025
030a003
handle case for when ContinuousApproximator.summary_network is None
thegialeo Apr 2, 2025
ac38528
add unit test case for when ContinuousApproximator.summary_network = …
thegialeo Apr 2, 2025
94066eb
unit test for computed MMD values to be positive
thegialeo Apr 2, 2025
bfbfbb1
add test cases for same and different distributions mmd hypothesis te…
thegialeo Apr 2, 2025
df6c5b1
add Raises for unequal shapes of observed and reference summaries on …
thegialeo Apr 2, 2025
7265e4b
add Raises to docstring
thegialeo Apr 2, 2025
a65b398
add test cases for Raises
thegialeo Apr 2, 2025
70f1f41
add test cases for indirect Raises through compute_hypothesis_test
thegialeo Apr 2, 2025
7458a8b
add Raises to compute_hypothesis_test for unmatching observed and ref…
thegialeo Apr 2, 2025
6c61996
rename unit test functions
thegialeo Apr 2, 2025
7162725
remove transitive Raises unit test to avoid coupling in the testing
thegialeo Apr 3, 2025
4ec5e95
adjust mock summary_network in unit tests to be a deterministic trans…
thegialeo Apr 3, 2025
d9b62e5
allow approximator argument to also be of bayesflow.network.SummaryNe…
thegialeo Apr 3, 2025
e11c36d
add Raise to docstring if approximator is not ContinuousApproximator …
thegialeo Apr 3, 2025
a09a9bc
add test cases for when approximator argument is bayesflow.networks.S…
thegialeo Apr 3, 2025
3abcef5
edit docstring
thegialeo Apr 3, 2025
e727c7a
update docstrings
thegialeo Apr 17, 2025
6d001ec
create bootstrap_comparison signature as proposed by PR comments
thegialeo Apr 17, 2025
51c35db
implement bootstrap_comparison
thegialeo Apr 17, 2025
ca449b7
rename compute_mmd_hypothesis_test to mmd_comparison_from_summaries a…
thegialeo Apr 17, 2025
bcf298b
rename compute_mmd_hypothesis_test to mmd_comparison
thegialeo Apr 17, 2025
67b7eec
add bootstrap_comparison import + import of renamed functions
thegialeo Apr 17, 2025
456dd9b
adjust unit tests for compute_hypothesis_test_from_summaries to new m…
thegialeo Apr 17, 2025
4d621f0
adjust unit tests for compute_hypothesis_test to new mmd_comparison f…
thegialeo Apr 17, 2025
78148a5
add unit test for exception raise if approximator is not ContinuousAp…
thegialeo Apr 17, 2025
5a78654
correct unit testX test_mmd_comparison_approximator_incorrect_instance
thegialeo Apr 17, 2025
ca6727e
add unit test for bootstrap_comparison output shape
thegialeo Apr 17, 2025
2d7f9b1
add unit tests for bootstrap_comparison with simple distributions (sa…
thegialeo Apr 17, 2025
0c006ed
add unit test for bootstrap_comparison exception raise for mismatched…
thegialeo Apr 17, 2025
1076006
add unit test for bootstrap_comparison exception raise for num_null_s…
thegialeo Apr 17, 2025
a35bbb3
correct mistake: bootstrap_comparison should raise exception for num_…
thegialeo Apr 17, 2025
377d88d
remove legacy unit tests
thegialeo Apr 17, 2025
13f457c
see changes in exception raises of previous commit
thegialeo Apr 17, 2025
4dac297
Merge branch 'feat-hypothesis-test' of github.com:thegialeo/bayesflow…
thegialeo Apr 17, 2025
4e20f4a
Merge branch 'bayesflow-org:main' into feat-hypothesis-test
thegialeo Apr 17, 2025
f328aea
update module docstring
thegialeo Apr 17, 2025
3d811c8
Merge branch 'dev' into feat-hypothesis-test
vpratz Apr 22, 2025
39144ed
formatting: remove trailing whitespace
vpratz Apr 22, 2025
5bf5cb1
formatting: remove large headings in tests
vpratz Apr 22, 2025
a39ac6e
adjust mock comparison_fn in unit tests of bootstrap_comparison to ac…
thegialeo Apr 22, 2025
1485e82
formatting to avoid too long lines
thegialeo Apr 22, 2025
0db4b5f
[no ci] Merge remote-tracking branch 'upstream/dev' into feat-hypothe…
vpratz Apr 25, 2025
5f2625a
Add summary_outputs method to approximator.
vpratz Apr 25, 2025
b18504b
refactor summary space distance function
vpratz Apr 25, 2025
7ee14fa
Rename mmd_hypothesis_test.py to model_misspecification.py
vpratz Apr 25, 2025
8c363fc
Merge remote-tracking branch 'upstream/dev' into feat-hypothesis-test
vpratz Apr 28, 2025
eba5f19
Merge remote-tracking branch 'upstream/dev' into feat-hypothesis-test
vpratz Apr 28, 2025
87c061c
add test for case summary_network=None in summary_space_comparison
vpratz Apr 28, 2025
e23206d
rename summary_outputs to summaries
vpratz Apr 30, 2025
9a5212d
Merge remote-tracking branch 'upstream/dev' into feat-hypothesis-test
vpratz May 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions bayesflow/approximators/continuous_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,39 @@ def _sample(
**filter_kwargs(kwargs, self.inference_network.sample),
)

def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
    """
    Computes the summaries of given data.

    The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Dictionary of data as NumPy arrays.
    **kwargs : dict
        Additional keyword arguments for the adapter and the summary network.

    Returns
    -------
    summaries : Tensor
        The output of the summary network applied to the adapted `summary_variables`. The value is returned
        exactly as produced by the summary network; no conversion to NumPy is performed here.

    Raises
    ------
    ValueError
        If the approximator does not have a summary network, or the adapter does not produce the output required
        by the summary network.
    """
    if self.summary_network is None:
        raise ValueError("A summary network is required to compute summaries.")

    data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
    if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
        raise ValueError("Summary variables are required to compute summaries.")

    summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
    # NOTE(review): filter_kwargs presumably keeps only the kwargs accepted by the summary network's `call`.
    summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
    return summaries

def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
"""
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
Expand Down
33 changes: 33 additions & 0 deletions bayesflow/approximators/model_comparison_approximator.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
output = self.logits_projector(output)

return output

def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
    """
    Runs the given data through the adapter and the summary network.

    Parameters
    ----------
    data : Mapping[str, np.ndarray]
        Dictionary of data as NumPy arrays.
    **kwargs : dict
        Additional keyword arguments for the adapter and the summary network.

    Returns
    -------
    summaries : Tensor
        The output of the summary network applied to the adapted `summary_variables`, returned as produced by
        the summary network.

    Raises
    ------
    ValueError
        If the approximator does not have a summary network, or the adapter does not produce the output required
        by the summary network.
    """
    network = self.summary_network
    if network is None:
        raise ValueError("A summary network is required to compute summaries.")

    adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
    has_summary_variables = "summary_variables" in adapted and adapted["summary_variables"] is not None
    if not has_summary_variables:
        raise ValueError("Summary variables are required to compute summaries.")

    network_input = keras.ops.convert_to_tensor(adapted["summary_variables"])
    # NOTE(review): filter_kwargs presumably keeps only the kwargs accepted by the summary network's `call`.
    network_kwargs = filter_kwargs(kwargs, self.summary_network.call)
    return network(network_input, **network_kwargs)
9 changes: 7 additions & 2 deletions bayesflow/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""
r"""
A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s.
"""

from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
from .metrics import (
bootstrap_comparison,
calibration_error,
posterior_contraction,
summary_space_comparison,
)

from .plots import (
calibration_ecdf,
Expand Down
1 change: 1 addition & 0 deletions bayesflow/diagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .root_mean_squared_error import root_mean_squared_error
from .expected_calibration_error import expected_calibration_error
from .classifier_two_sample_test import classifier_two_sample_test
from .model_misspecification import bootstrap_comparison, summary_space_comparison
155 changes: 155 additions & 0 deletions bayesflow/diagnostics/metrics/model_misspecification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""
This module provides functions for computing distances between observed samples and reference samples, together with
the distribution of distances within the reference samples (the null distribution) used for hypothesis testing.
"""

from collections.abc import Mapping, Callable

import numpy as np
from keras.ops import convert_to_numpy, convert_to_tensor

from bayesflow.approximators import ContinuousApproximator
from bayesflow.metrics.functional import maximum_mean_discrepancy
from bayesflow.types import Tensor


def bootstrap_comparison(
    observed_samples: np.ndarray,
    reference_samples: np.ndarray,
    comparison_fn: Callable[[Tensor, Tensor], Tensor],
    num_null_samples: int = 100,
) -> tuple[float, np.ndarray]:
    """Computes the distance between observed and reference samples and generates a distribution of null sample
    distances by bootstrapping for hypothesis testing.

    Each null distance is obtained by drawing `num_observed` rows from the reference samples (with replacement)
    and comparing that bootstrap sample against the full set of reference samples.

    Parameters
    ----------
    observed_samples : np.ndarray
        Observed samples, shape (num_observed, ...).
    reference_samples : np.ndarray
        Reference samples, shape (num_reference, ...).
    comparison_fn : Callable[[Tensor, Tensor], Tensor]
        Function to compute the distance metric. It is called with two float32 tensors and its result must be
        convertible to a scalar.
    num_null_samples : int
        Number of null samples to generate for hypothesis testing. Default is 100.

    Returns
    -------
    distance_observed : float
        The distance value between observed and reference samples.
    distance_null : np.ndarray
        A distribution of distance values under the null hypothesis, shape (num_null_samples,).

    Raises
    ------
    ValueError
        - If the number of observed samples exceeds the number of reference samples.
        - If the shapes of observed and reference samples do not match on dimensions besides the first one.
    """
    num_observed: int = observed_samples.shape[0]
    num_reference: int = reference_samples.shape[0]

    if num_observed > num_reference:
        # The trailing space keeps the two adjacent string literals from running together.
        raise ValueError(
            f"Number of observed samples ({num_observed}) cannot exceed "
            f"the number of reference samples ({num_reference}) for bootstrapping."
        )
    if observed_samples.shape[1:] != reference_samples.shape[1:]:
        raise ValueError(
            "Expected observed and reference samples to have the same shape, "
            f"but got {observed_samples.shape[1:]} != {reference_samples.shape[1:]}."
        )

    observed_samples_tensor: Tensor = convert_to_tensor(observed_samples, dtype="float32")
    reference_samples_tensor: Tensor = convert_to_tensor(reference_samples, dtype="float32")

    distance_null_samples: np.ndarray = np.zeros(num_null_samples, dtype=np.float64)
    for i in range(num_null_samples):
        # Resample row indices with replacement; the bootstrap sample has the same size as the observed sample.
        bootstrap_idx: np.ndarray = np.random.randint(0, num_reference, size=num_observed)
        bootstrap_samples: np.ndarray = reference_samples[bootstrap_idx]
        bootstrap_samples_tensor: Tensor = convert_to_tensor(bootstrap_samples, dtype="float32")
        distance_null_samples[i] = convert_to_numpy(comparison_fn(bootstrap_samples_tensor, reference_samples_tensor))

    distance_observed_tensor: Tensor = comparison_fn(
        observed_samples_tensor,
        reference_samples_tensor,
    )

    distance_observed: float = float(convert_to_numpy(distance_observed_tensor))

    return distance_observed, distance_null_samples


def summary_space_comparison(
    observed_data: Mapping[str, np.ndarray],
    reference_data: Mapping[str, np.ndarray],
    approximator: ContinuousApproximator,
    num_null_samples: int = 100,
    comparison_fn: Callable = maximum_mean_discrepancy,
    **kwargs,
) -> tuple[float, np.ndarray]:
    """Computes the distance between observed and reference data in the summary space and
    generates a distribution of distance values under the null hypothesis to assess model misspecification.

    By default, the Maximum Mean Discrepancy (MMD) is used as a distance function.

    [1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
    inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
    URL: https://arxiv.org/abs/2112.08866

    Parameters
    ----------
    observed_data : dict[str, np.ndarray]
        Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximator's adapter and
        passed through its summary network.
    reference_data : dict[str, np.ndarray]
        Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximator's adapter and
        passed through its summary network.
    approximator : ContinuousApproximator
        An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics
        from the data.
    num_null_samples : int, optional
        Number of null samples to generate for hypothesis testing. Default is 100.
    comparison_fn : Callable, optional
        Distance function to compare the data in the summary space.
    **kwargs : dict
        Additional keyword arguments forwarded to `approximator.summaries` for both the observed and the
        reference data.

    Returns
    -------
    distance_observed : float
        The distance value (MMD by default) between observed and reference summaries.
    distance_null : np.ndarray
        A distribution of distance values (MMD by default) under the null hypothesis.

    Raises
    ------
    ValueError
        If approximator is not an instance of ContinuousApproximator or does not have a summary network.
    """

    if not isinstance(approximator, ContinuousApproximator):
        raise ValueError("The approximator must be an instance of ContinuousApproximator.")

    if getattr(approximator, "summary_network", None) is None:
        # Not every callable has a __name__ (e.g. functools.partial objects); fall back to repr so that
        # building the error message can never raise on its own.
        comparison_fn_name = (
            "bayesflow.metrics.functional.maximum_mean_discrepancy"
            if comparison_fn is maximum_mean_discrepancy
            else getattr(comparison_fn, "__name__", repr(comparison_fn))
        )
        raise ValueError(
            "The approximator must have a summary network. If you have manually crafted summary "
            "statistics, or want to compare raw data and not summary statistics, please use the "
            f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
        )

    # Forward the user-supplied keyword arguments instead of silently dropping them.
    observed_summaries = convert_to_numpy(approximator.summaries(observed_data, **kwargs))
    reference_summaries = convert_to_numpy(approximator.summaries(reference_data, **kwargs))

    distance_observed, distance_null = bootstrap_comparison(
        observed_samples=observed_summaries,
        reference_samples=reference_summaries,
        comparison_fn=comparison_fn,
        num_null_samples=num_null_samples,
    )

    return distance_observed, distance_null
31 changes: 31 additions & 0 deletions tests/test_approximators/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator):
num_batches = 2
data = simulator.sample((num_batches * batch_size,))
return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches)


@pytest.fixture()
def mean_std_summary_network():
    """Provide a new ``MeanStdSummaryNetwork`` instance from the shared test utilities."""
    # Imported lazily inside the fixture, as elsewhere in this conftest.
    from tests.utils import MeanStdSummaryNetwork

    network = MeanStdSummaryNetwork()
    return network


@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
def approximator_with_summaries(request):
    """Build one approximator per parametrization, each with a default adapter and no networks attached.

    The networks are all passed as ``None``; the tests using this fixture patch in the summary network they need.
    """
    from bayesflow.adapters import Adapter

    kind = request.param
    default_adapter = Adapter()

    if kind == "continuous_approximator":
        from bayesflow.approximators import ContinuousApproximator

        return ContinuousApproximator(adapter=default_adapter, inference_network=None, summary_network=None)

    if kind == "point_approximator":
        from bayesflow.approximators import PointApproximator

        return PointApproximator(adapter=default_adapter, inference_network=None, summary_network=None)

    if kind == "model_comparison_approximator":
        from bayesflow.approximators import ModelComparisonApproximator

        return ModelComparisonApproximator(
            num_models=2, classifier_network=None, adapter=default_adapter, summary_network=None
        )

    # Unknown parametrization: fail loudly rather than silently returning None.
    raise ValueError("Invalid param for approximator class.")
23 changes: 23 additions & 0 deletions tests/test_approximators/test_summaries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
from tests.utils import assert_allclose
import keras


def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
    """With a summary network attached, `summaries` returns the network output for the summary variables."""
    monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)

    inputs = {"summary_variables": keras.ops.ones((2, 3))}
    # Expected output has one row per input row: a column of ones followed by a column of zeros
    # (presumably mean and standard deviation of the all-ones input, per the helper network's name).
    ones_column = keras.ops.ones((2,))
    zeros_column = keras.ops.zeros((2,))
    expected = keras.ops.stack([ones_column, zeros_column], axis=-1)

    actual = approximator_with_summaries.summaries(inputs)
    assert_allclose(actual, expected)


def test_no_summary_network(approximator_with_summaries, monkeypatch):
    """`summaries` must raise a ValueError when the approximator has no summary network."""
    monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
    inputs = {"summary_variables": keras.ops.ones((2, 3))}

    with pytest.raises(ValueError):
        approximator_with_summaries.summaries(inputs)


def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
    """`summaries` must raise a ValueError when the data contains no summary variables."""
    monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
    empty_data = {}

    with pytest.raises(ValueError):
        approximator_with_summaries.summaries(empty_data)
14 changes: 14 additions & 0 deletions tests/test_diagnostics/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,17 @@ def history():
}

return h


@pytest.fixture()
def adapter():
    """Provide a default adapter for "parameters" that renames "observables" to "summary_variables"."""
    from bayesflow.adapters import Adapter

    default_adapter = Adapter.create_default("parameters")
    return default_adapter.rename("observables", "summary_variables")


@pytest.fixture()
def summary_network():
    """Provide a new ``MeanStdSummaryNetwork`` instance from the shared test utilities."""
    from tests.utils import MeanStdSummaryNetwork

    network = MeanStdSummaryNetwork()
    return network
Loading