Skip to content

Commit e23206d

Browse files
committed
rename summary_outputs to summaries
- add summaries function to ModelCommparisonApproximator as well - add tests for the approximator.summaries functions
1 parent 87c061c commit e23206d

File tree

8 files changed

+107
-17
lines changed

8 files changed

+107
-17
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,9 @@ def _sample(
400400
**filter_kwargs(kwargs, self.inference_network.sample),
401401
)
402402

403-
def summary_outputs(self, data: Mapping[str, np.ndarray], **kwargs):
403+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
404404
"""
405-
Computes the summary outputs of given data.
405+
Computes the summaries of given data.
406406
407407
The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
408408
@@ -415,7 +415,7 @@ def summary_outputs(self, data: Mapping[str, np.ndarray], **kwargs):
415415
416416
Returns
417417
-------
418-
summary_outputs : np.ndarray
418+
summaries : np.ndarray
419419
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
420420
421421
Raises
@@ -425,13 +425,13 @@ def summary_outputs(self, data: Mapping[str, np.ndarray], **kwargs):
425425
by the summary network.
426426
"""
427427
if self.summary_network is None:
428-
raise ValueError("A summary network is required to compute summary outputs.")
428+
raise ValueError("A summary network is required to compute summeries.")
429429
data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
430430
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
431-
raise ValueError("Summary variables are required to compute summary outputs")
431+
raise ValueError("Summary variables are required to compute summaries.")
432432
summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
433-
summary_outputs = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
434-
return summary_outputs
433+
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
434+
return summaries
435435

436436
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
437437
"""

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,36 @@ def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tens
345345
output = self.logits_projector(output)
346346

347347
return output
348+
349+
def summaries(self, data: Mapping[str, np.ndarray], **kwargs):
350+
"""
351+
Computes the summaries of given data.
352+
353+
The `data` dictionary is preprocessed using the `adapter` and passed through the summary network.
354+
355+
Parameters
356+
----------
357+
data : Mapping[str, np.ndarray]
358+
Dictionary of data as NumPy arrays.
359+
**kwargs : dict
360+
Additional keyword arguments for the adapter and the summary network.
361+
362+
Returns
363+
-------
364+
summaries : np.ndarray
365+
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
366+
367+
Raises
368+
------
369+
ValueError
370+
If the approximator does not have a summary network, or the adapter does not produce the output required
371+
by the summary network.
372+
"""
373+
if self.summary_network is None:
374+
raise ValueError("A summary network is required to compute summaries.")
375+
data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
376+
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
377+
raise ValueError("Summary variables are required to compute summaries.")
378+
summary_variables = keras.ops.convert_to_tensor(data_adapted["summary_variables"])
379+
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
380+
return summaries

bayesflow/diagnostics/metrics/model_misspecification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def summary_space_comparison(
142142
"statistics, or want to compare raw data and not summary statistics, please use the "
143143
f"`bootstrap_comparison` function with `comparison_fn={comparison_fn_name}` on the respective arrays."
144144
)
145-
observed_summaries = convert_to_numpy(approximator.summary_outputs(observed_data))
146-
reference_summaries = convert_to_numpy(approximator.summary_outputs(reference_data))
145+
observed_summaries = convert_to_numpy(approximator.summaries(observed_data))
146+
reference_summaries = convert_to_numpy(approximator.summaries(reference_data))
147147

148148
distance_observed, distance_null = bootstrap_comparison(
149149
observed_samples=observed_summaries,

tests/test_approximators/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,34 @@ def validation_dataset(batch_size, adapter, simulator):
163163
num_batches = 2
164164
data = simulator.sample((num_batches * batch_size,))
165165
return OfflineDataset(data=data, adapter=adapter, batch_size=batch_size, workers=4, max_queue_size=num_batches)
166+
167+
168+
@pytest.fixture()
169+
def mean_std_summary_network():
170+
from tests.utils import MeanStdSummaryNetwork
171+
172+
return MeanStdSummaryNetwork()
173+
174+
175+
@pytest.fixture(params=["continuous_approximator", "point_approximator", "model_comparison_approximator"])
176+
def approximator_with_summaries(request):
177+
from bayesflow.adapters import Adapter
178+
179+
adapter = Adapter()
180+
match request.param:
181+
case "continuous_approximator":
182+
from bayesflow.approximators import ContinuousApproximator
183+
184+
return ContinuousApproximator(adapter=adapter, inference_network=None, summary_network=None)
185+
case "point_approximator":
186+
from bayesflow.approximators import PointApproximator
187+
188+
return PointApproximator(adapter=adapter, inference_network=None, summary_network=None)
189+
case "model_comparison_approximator":
190+
from bayesflow.approximators import ModelComparisonApproximator
191+
192+
return ModelComparisonApproximator(
193+
num_models=2, classifier_network=None, adapter=adapter, summary_network=None
194+
)
195+
case _:
196+
raise ValueError("Invalid param for approximator class.")
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from tests.utils import assert_allclose
3+
import keras
4+
5+
6+
def test_valid_summaries(approximator_with_summaries, mean_std_summary_network, monkeypatch):
7+
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
8+
summaries = approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
9+
assert_allclose(summaries, keras.ops.stack([keras.ops.ones((2,)), keras.ops.zeros((2,))], axis=-1))
10+
11+
12+
def test_no_summary_network(approximator_with_summaries, monkeypatch):
13+
monkeypatch.setattr(approximator_with_summaries, "summary_network", None)
14+
15+
with pytest.raises(ValueError):
16+
approximator_with_summaries.summaries({"summary_variables": keras.ops.ones((2, 3))})
17+
18+
19+
def test_no_summary_variables(approximator_with_summaries, mean_std_summary_network, monkeypatch):
20+
monkeypatch.setattr(approximator_with_summaries, "summary_network", mean_std_summary_network)
21+
22+
with pytest.raises(ValueError):
23+
approximator_with_summaries.summaries({})

tests/test_diagnostics/conftest.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,6 @@ def adapter():
8989

9090
@pytest.fixture()
9191
def summary_network():
92-
from bayesflow.networks import SummaryNetwork
92+
from tests.utils import MeanStdSummaryNetwork
9393

94-
class DummySummaryNetwork(SummaryNetwork):
95-
def call(self, x):
96-
summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1)
97-
print("summary_outputs", summary_outputs)
98-
return summary_outputs
99-
100-
return DummySummaryNetwork()
94+
return MeanStdSummaryNetwork()

tests/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from .callbacks import *
33
from .check_combinations import *
44
from .jupyter import *
5+
from .networks import *
56
from .ops import *

tests/utils/networks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from bayesflow.networks import SummaryNetwork
2+
import keras
3+
4+
5+
class MeanStdSummaryNetwork(SummaryNetwork):
6+
def call(self, x):
7+
summary_outputs = keras.ops.stack([keras.ops.mean(x, axis=-1), keras.ops.std(x, axis=-1)], axis=-1)
8+
return summary_outputs

0 commit comments

Comments
 (0)