Skip to content

Commit b18504b

Browse files
committed
refactor summary space distance function
- remove somewhat redundant mmd_comparison_from_summaries function - rename mmd_comparison to the more general summary_space_comparison, with configurable distance function (default MMD) - only allow calling summary_space_comparison when we can obtain the summary variables directly from the approximator. For all other use cases, directly refer to bootstrap_comparison - update tests to reflect those changes - remove redundant docstrings from the module
1 parent 5f2625a commit b18504b

File tree

5 files changed

+166
-218
lines changed

5 files changed

+166
-218
lines changed

bayesflow/diagnostics/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
A collection of plotting utilities and metrics for evaluating trained :py:class:`~bayesflow.workflows.Workflow`\ s.
33
"""
44

5-
from .metrics import root_mean_squared_error, calibration_error, posterior_contraction
5+
from .metrics import (
6+
bootstrap_comparison,
7+
calibration_error,
8+
posterior_contraction,
9+
summary_space_comparison,
10+
)
611

712
from .plots import (
813
calibration_ecdf,

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from .root_mean_squared_error import root_mean_squared_error
44
from .expected_calibration_error import expected_calibration_error
55
from .classifier_two_sample_test import classifier_two_sample_test
6-
from .mmd_hypothesis_test import bootstrap_comparison, mmd_comparison, mmd_comparison_from_summaries
6+
from .mmd_hypothesis_test import bootstrap_comparison, summary_space_comparison
Lines changed: 54 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,34 @@
11
"""
22
This module provides functions for computing distances between observation samples and reference samples with distance
33
distributions within the reference samples for hypothesis testing.
4-
5-
Functions:
6-
----------
7-
- bootstrap_comparison: Computes distance between observed and reference samples and generates a distribution of null
8-
sample distances by bootstrapping for hypothesis testing.
9-
- mmd_comparison_from_summaries: Computes the Maximum Mean Discrepancy (MMD) between observed and reference summaries
10-
and generates a distribution of MMD values under the null hypothesis to assess model misspecification.
11-
- mmd_comparison: Computes the Maximum Mean Discrepancy (MMD) between observed and reference data and generates a
12-
distribution of MMD values under the null hypothesis to assess model misspecification.
13-
14-
Dependencies:
15-
-------------
16-
- numpy: For numerical operations.
17-
- keras.ops: For converting data to numpy and tensor formats.
18-
- bayesflow.networks: Provides the `SummaryNetwork` class for extracting summary statistics.
19-
- bayesflow.approximators: Provides the `Approximator` class for extracting summary statistics.
20-
- bayesflow.metrics: Provides the `maximum_mean_discrepancy` function for computing the MMD.
214
"""
225

23-
import typing
6+
from collections.abc import Mapping, Callable
247

258
import numpy as np
269
from keras.ops import convert_to_numpy, convert_to_tensor
2710

2811
from bayesflow.approximators import ContinuousApproximator
2912
from bayesflow.metrics.functional import maximum_mean_discrepancy
30-
from bayesflow.networks import SummaryNetwork
3113
from bayesflow.types import Tensor
3214

3315

3416
def bootstrap_comparison(
3517
observed_samples: np.ndarray,
3618
reference_samples: np.ndarray,
37-
comparison_fn: typing.Callable[[Tensor, Tensor], Tensor],
19+
comparison_fn: Callable[[Tensor, Tensor], Tensor],
3820
num_null_samples: int = 100,
3921
) -> tuple[float, np.ndarray]:
40-
"""Compute distance between observed and reference samples and generated a distribution of null sample distances by
41-
bootstrapping for hypothesis testing.
22+
"""Computes the distance between observed and reference samples and generates a distribution of null sample
23+
distances by bootstrapping for hypothesis testing.
4224
4325
Parameters
4426
----------
4527
observed_samples : np.ndarray
4628
Observed samples, shape (num_observed, ...).
4729
reference_samples : np.ndarray
4830
Reference samples, shape (num_reference, ...).
49-
comparison_fn : typing.Callable[[Tensor, Tensor], Tensor]
31+
comparison_fn : Callable[[Tensor, Tensor], Tensor]
5032
Function to compute the distance metric.
5133
num_null_samples : int
5234
Number of null samples to generate for hypothesis testing. Default is 100.
@@ -98,108 +80,76 @@ def bootstrap_comparison(
9880
return distance_observed, distance_null_samples
9981

10082

101-
def mmd_comparison_from_summaries(
102-
observed_summaries: np.ndarray,
103-
reference_summaries: np.ndarray,
83+
def summary_space_comparison(
84+
observed_data: Mapping[str, np.ndarray],
85+
reference_data: Mapping[str, np.ndarray],
86+
approximator: ContinuousApproximator,
10487
num_null_samples: int = 100,
88+
comparison_fn: Callable = maximum_mean_discrepancy,
89+
**kwargs,
10590
) -> tuple[float, np.ndarray]:
106-
"""Computes the Maximum Mean Discrepancy (MMD) between observed and reference summaries and generates a distribution
107-
of MMD values under the null hypothesis to assess model misspecification.
108-
109-
[1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
110-
inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
111-
URL: https://arxiv.org/abs/2112.08866
91+
"""Computes the distance between observed and reference data in the summary space and
92+
generates a distribution of distance values under the null hypothesis to assess model misspecification.
11293
113-
114-
Parameters
115-
----------
116-
observed_summary : np.ndarray
117-
Summary statistics of observed data, shape (num_observed, ...).
118-
reference_summary : np.ndarray
119-
Summary statistics of reference data, shape (num_reference, ...).
120-
num_null_samples : int
121-
Number of null samples to generate for hypothesis testing. Default is 100.
122-
123-
Returns
124-
-------
125-
mmd_observed : float
126-
The MMD value between observed and reference summaries.
127-
mmd_null : np.ndarray
128-
A distribution of MMD values under the null hypothesis.
129-
"""
130-
mmd_observed, mmd_null_samples = bootstrap_comparison(
131-
observed_samples=observed_summaries,
132-
reference_samples=reference_summaries,
133-
comparison_fn=maximum_mean_discrepancy,
134-
num_null_samples=num_null_samples,
135-
)
136-
137-
return mmd_observed, mmd_null_samples
138-
139-
140-
def mmd_comparison(
141-
observed_data: np.ndarray,
142-
reference_data: np.ndarray,
143-
approximator: ContinuousApproximator | SummaryNetwork,
144-
num_null_samples: int = 100,
145-
) -> tuple[float, np.ndarray]:
146-
"""Computes the Maximum Mean Discrepancy (MMD) between observed and reference data and generates a distribution of
147-
MMD values under the null hypothesis to assess model misspecification.
94+
By default, the Maximum Mean Discrepancy (MMD) is used as a distance function.
14895
14996
[1] M. Schmitt, P.-C. Bürkner, U. Köthe, and S. T. Radev, "Detecting model misspecification in amortized Bayesian
15097
inference with neural networks," arXiv e-prints, Dec. 2021, Art. no. arXiv:2112.08866.
15198
URL: https://arxiv.org/abs/2112.08866
15299
153-
154100
Parameters
155101
----------
156-
observed_data : np.ndarray
157-
Observed data, shape (num_observed, ...).
158-
reference_data : np.ndarray
159-
Reference data, shape (num_reference, ...).
160-
approximator : ContinuousApproximator or SummaryNetwork
161-
An instance of the ContinuousApproximator or SummaryNetwork class use to extract summary statistics from data.
162-
num_null_samples : int
102+
observed_data : dict[str, np.ndarray]
103+
Dictionary of observed data as NumPy arrays, which will be preprocessed by the approximator's adapter and passed
104+
through its summary network.
105+
reference_data : dict[str, np.ndarray]
106+
Dictionary of reference data as NumPy arrays, which will be preprocessed by the approximator's adapter and passed
107+
through its summary network.
108+
approximator : ContinuousApproximator
109+
An instance of :py:class:`~bayesflow.approximators.ContinuousApproximator` used to compute summary statistics
110+
from the data.
111+
num_null_samples : int, optional
163112
Number of null samples to generate for hypothesis testing. Default is 100.
113+
comparison_fn : Callable, optional
114+
Distance function to compare the data in the summary space. Default is `maximum_mean_discrepancy`.
115+
**kwargs : dict
116+
Additional keyword arguments. NOTE: currently accepted but not forwarded to the adapter or the comparison.
164117
165118
Returns
166119
-------
167-
mmd_observed : float
168-
The MMD value between observed and reference data.
169-
mmd_null : np.ndarray
120+
distance_observed : float
121+
The distance between observed and reference summaries, as computed by `comparison_fn` (MMD by default).
122+
distance_null : np.ndarray
170123
A distribution of distance values under the null hypothesis (MMD values by default).
171124
172-
Raises:
125+
Raises
173126
------
174127
ValueError
175-
- If the shapes of observed and reference data do not match on dimensions besides the first one.
176-
- If approximator is not an instance of ContinuousApproximator or SummaryNetwork.
128+
If approximator is not an instance of ContinuousApproximator or does not have a summary network.
177129
"""
178-
if observed_data.shape[1:] != reference_data.shape[1:]:
130+
131+
if not isinstance(approximator, ContinuousApproximator):
132+
raise ValueError("The approximator must be an instance of ContinuousApproximator.")
133+
134+
if not hasattr(approximator, "summary_network") or approximator.summary_network is None:
135+
comparison_fn_name = (
136+
"bayesflow.metrics.functional.maximum_mean_discrepancy"
137+
if comparison_fn is maximum_mean_discrepancy
138+
else comparison_fn.__name__
139+
)
179140
raise ValueError(
180-
f"Expected observed and reference data to have the same shape, "
181-
f"but got {observed_data.shape[1:]} != {reference_data.shape[1:]}."
141+
"The approximator must have a summary network. If you have manually crafted summary "
142+
"statistics, or want to compare raw data and not summary statistics, please use the "
143+
f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
182144
)
145+
observed_summaries = convert_to_numpy(approximator.summary_outputs(observed_data))
146+
reference_summaries = convert_to_numpy(approximator.summary_outputs(reference_data))
183147

184-
if isinstance(approximator, ContinuousApproximator):
185-
if approximator.summary_network is not None:
186-
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
187-
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
188-
observed_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(observed_data_tensor))
189-
reference_summaries: np.ndarray = convert_to_numpy(approximator.summary_network(reference_data_tensor))
190-
else:
191-
observed_summaries: np.ndarray = observed_data
192-
reference_summaries: np.ndarray = reference_data
193-
elif isinstance(approximator, SummaryNetwork):
194-
observed_data_tensor: Tensor = convert_to_tensor(observed_data)
195-
reference_data_tensor: Tensor = convert_to_tensor(reference_data)
196-
observed_summaries: np.ndarray = convert_to_numpy(approximator(observed_data_tensor))
197-
reference_summaries: np.ndarray = convert_to_numpy(approximator(reference_data_tensor))
198-
else:
199-
raise ValueError("The approximator must be an instance of ContinuousApproximator or SummaryNetwork.")
200-
201-
mmd_observed, mmd_null = mmd_comparison_from_summaries(
202-
observed_summaries, reference_summaries, num_null_samples=num_null_samples
148+
distance_observed, distance_null = bootstrap_comparison(
149+
observed_samples=observed_summaries,
150+
reference_samples=reference_summaries,
151+
comparison_fn=comparison_fn,
152+
num_null_samples=num_null_samples,
203153
)
204154

205-
return mmd_observed, mmd_null
155+
return distance_observed, distance_null

tests/test_diagnostics/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,23 @@ def history():
7878
}
7979

8080
return h
81+
82+
83+
@pytest.fixture()
84+
def adapter():
85+
from bayesflow.adapters import Adapter
86+
87+
return Adapter.create_default("parameters").rename("observables", "summary_variables")
88+
89+
90+
@pytest.fixture()
91+
def summary_network():
92+
from bayesflow.networks import SummaryNetwork
93+
94+
class DummySummaryNetwork(SummaryNetwork):
95+
def call(self, x):
96+
summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1)
97+
print("summary_outputs", summary_outputs)
98+
return summary_outputs
99+
100+
return DummySummaryNetwork()

0 commit comments

Comments
 (0)