Skip to content

Commit 87810b1

Browse files
authored
Merge branch 'dev' into dev2
2 parents ed5f9c3 + d13da56 commit 87810b1

40 files changed

+1900
-1049
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
name: Multi-Backend Tests
32

43
on:
@@ -16,15 +15,14 @@ defaults:
1615
run:
1716
shell: bash
1817

19-
2018
jobs:
2119
test:
2220
name: Run Multi-Backend Tests
2321

2422
strategy:
2523
matrix:
2624
os: [ubuntu-latest, windows-latest]
27-
python-version: ["3.10"] # we usually only need to test the oldest python version
25+
python-version: ["3.10"] # we usually only need to test the oldest python version
2826
backend: ["jax", "tensorflow", "torch"]
2927

3028
runs-on: ${{ matrix.os }}

CHANGELOG.rst

Lines changed: 0 additions & 92 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ complex to be described analytically.
5151

5252
## Install
5353

54-
We currently support Python 3.10 to 3.12. You can install the latest stable version from PyPI using:
54+
We currently support Python 3.10 to 3.13. You can install the latest stable version from PyPI using:
5555

5656
```bash
5757
pip install "bayesflow>=2.0"

bayesflow/approximators/continuous_approximator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,13 @@ def _sample(
535535
inference_conditions = keras.ops.broadcast_to(
536536
inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
537537
)
538-
batch_shape = keras.ops.shape(inference_conditions)[:-1]
538+
539+
if hasattr(self.inference_network, "base_distribution"):
540+
target_shape_len = len(self.inference_network.base_distribution.dims)
541+
else:
542+
# point approximator has no base_distribution
543+
target_shape_len = 1
544+
batch_shape = keras.ops.shape(inference_conditions)[:-target_shape_len]
539545
else:
540546
batch_shape = (num_samples,)
541547

@@ -564,11 +570,11 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
564570
if self.summary_network is None:
565571
raise ValueError("A summary network is required to compute summaries.")
566572

567-
data_adapted = self.adapter(data, strict=False, **kwargs)
573+
data_adapted = self._prepare_data(data, **kwargs)
568574
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
569575
raise ValueError("Summary variables are required to compute summaries.")
570576

571-
summary_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, data_adapted["summary_variables"])
577+
summary_variables = data_adapted["summary_variables"]
572578
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
573579
summaries = keras.ops.convert_to_numpy(summaries)
574580

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
433433
raise ValueError("Summary variables are required to compute summaries.")
434434

435435
summary_variables = keras.tree.map_structure(keras.ops.convert_to_tensor, data_adapted["summary_variables"])
436+
437+
if "summary_variables" in self.standardize:
438+
summary_variables = self.standardize_layers["summary_variables"](summary_variables)
439+
436440
summaries = self.summary_network(summary_variables, **filter_kwargs(kwargs, self.summary_network.call))
437441
summaries = keras.ops.convert_to_numpy(summaries)
438442

bayesflow/diagnostics/metrics/calibration_error.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def calibration_error(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
resolution: int = 20,
1415
aggregation: Callable = np.median,
1516
min_quantile: float = 0.005,
@@ -32,6 +33,18 @@ def calibration_error(
3233
By default, select all keys.
3334
variable_names : Sequence[str], optional (default = None)
3435
Optional variable names to show in the output.
36+
test_quantities : dict or None, optional, default: None
37+
A dict that maps plot titles to functions that compute
38+
test quantities based on estimate/target draws.
39+
40+
The dict keys are automatically added to ``variable_keys``
41+
and ``variable_names``.
42+
Test quantity functions are expected to accept a dict of draws with
43+
shape ``(batch_size, ...)`` as the first (typically only)
44+
positional argument and return an NumPy array of shape
45+
``(batch_size,)``.
46+
The functions do not have to deal with an additional
47+
sample dimension, as appropriate reshaping is done internally.
3548
resolution : int, optional, default: 20
3649
The number of credibility intervals (CIs) to consider
3750
aggregation : callable or None, optional, default: np.median
@@ -55,6 +68,19 @@ def calibration_error(
5568
The (inferred) variable names.
5669
"""
5770

71+
if test_quantities is not None:
72+
updated_data = compute_test_quantities(
73+
targets=targets,
74+
estimates=estimates,
75+
variable_keys=variable_keys,
76+
variable_names=variable_names,
77+
test_quantities=test_quantities,
78+
)
79+
variable_names = updated_data["variable_names"]
80+
variable_keys = updated_data["variable_keys"]
81+
estimates = updated_data["estimates"]
82+
targets = updated_data["targets"]
83+
5884
samples = dicts_to_arrays(
5985
estimates=estimates,
6086
targets=targets,

bayesflow/diagnostics/metrics/calibration_log_gamma.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
44
from scipy.stats import binom
55

6-
from ...utils.dict_utils import dicts_to_arrays
6+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
77

88

99
def calibration_log_gamma(
1010
estimates: Mapping[str, np.ndarray] | np.ndarray,
1111
targets: Mapping[str, np.ndarray] | np.ndarray,
1212
variable_keys: Sequence[str] = None,
1313
variable_names: Sequence[str] = None,
14+
test_quantities: dict[str, Callable] = None,
1415
num_null_draws: int = 1000,
1516
quantile: float = 0.05,
1617
):
@@ -41,6 +42,18 @@ def calibration_log_gamma(
4142
By default, select all keys.
4243
variable_names : Sequence[str], optional (default = None)
4344
Optional variable names to show in the output.
45+
test_quantities : dict or None, optional, default: None
46+
A dict that maps plot titles to functions that compute
47+
test quantities based on estimate/target draws.
48+
49+
The dict keys are automatically added to ``variable_keys``
50+
and ``variable_names``.
51+
Test quantity functions are expected to accept a dict of draws with
52+
shape ``(batch_size, ...)`` as the first (typically only)
53+
positional argument and return an NumPy array of shape
54+
``(batch_size,)``.
55+
The functions do not have to deal with an additional
56+
sample dimension, as appropriate reshaping is done internally.
4457
quantile : float in (0, 1), optional, default 0.05
4558
The quantile from the null distribution to be used as a threshold.
4659
A lower quantile increases sensitivity to deviations from uniformity.
@@ -57,6 +70,21 @@ def calibration_log_gamma(
5770
- "variable_names" : str
5871
The (inferred) variable names.
5972
"""
73+
74+
# Optionally, compute and prepend test quantities from draws
75+
if test_quantities is not None:
76+
updated_data = compute_test_quantities(
77+
targets=targets,
78+
estimates=estimates,
79+
variable_keys=variable_keys,
80+
variable_names=variable_names,
81+
test_quantities=test_quantities,
82+
)
83+
variable_names = updated_data["variable_names"]
84+
variable_keys = updated_data["variable_keys"]
85+
estimates = updated_data["estimates"]
86+
targets = updated_data["targets"]
87+
6088
samples = dicts_to_arrays(
6189
estimates=estimates,
6290
targets=targets,

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def posterior_contraction(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
aggregation: Callable | None = np.median,
1415
) -> dict[str, any]:
1516
"""
@@ -27,6 +28,18 @@ def posterior_contraction(
2728
By default, select all keys.
2829
variable_names : Sequence[str], optional (default = None)
2930
Optional variable names to show in the output.
31+
test_quantities : dict or None, optional, default: None
32+
A dict that maps plot titles to functions that compute
33+
test quantities based on estimate/target draws.
34+
35+
The dict keys are automatically added to ``variable_keys``
36+
and ``variable_names``.
37+
Test quantity functions are expected to accept a dict of draws with
38+
shape ``(batch_size, ...)`` as the first (typically only)
39+
positional argument and return an NumPy array of shape
40+
``(batch_size,)``.
41+
The functions do not have to deal with an additional
42+
sample dimension, as appropriate reshaping is done internally.
3043
aggregation : callable or None, optional (default = np.median)
3144
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
3245
If None is provided, the individual values are returned.
@@ -50,6 +63,20 @@ def posterior_contraction(
5063
indicate low contraction.
5164
"""
5265

66+
# Optionally, compute and prepend test quantities from draws
67+
if test_quantities is not None:
68+
updated_data = compute_test_quantities(
69+
targets=targets,
70+
estimates=estimates,
71+
variable_keys=variable_keys,
72+
variable_names=variable_names,
73+
test_quantities=test_quantities,
74+
)
75+
variable_names = updated_data["variable_names"]
76+
variable_keys = updated_data["variable_keys"]
77+
estimates = updated_data["estimates"]
78+
targets = updated_data["targets"]
79+
5380
samples = dicts_to_arrays(
5481
estimates=estimates,
5582
targets=targets,

0 commit comments

Comments
 (0)