Skip to content

Commit d13da56

Browse files
Adding custom test quantities to more diagnostics (#586)
* Adding custom test quantities to diagnostics * Fix Code Style * run ruff again * Adding tests * adding one stupid blank line * again run ruff --------- Co-authored-by: Paul-Christian Bürkner <[email protected]>
1 parent b2a9ded commit d13da56

File tree

10 files changed

+306
-8
lines changed

10 files changed

+306
-8
lines changed

bayesflow/diagnostics/metrics/calibration_error.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def calibration_error(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
resolution: int = 20,
1415
aggregation: Callable = np.median,
1516
min_quantile: float = 0.005,
@@ -32,6 +33,18 @@ def calibration_error(
3233
By default, select all keys.
3334
variable_names : Sequence[str], optional (default = None)
3435
Optional variable names to show in the output.
36+
test_quantities : dict or None, optional, default: None
37+
A dict that maps plot titles to functions that compute
38+
test quantities based on estimate/target draws.
39+
40+
The dict keys are automatically added to ``variable_keys``
41+
and ``variable_names``.
42+
Test quantity functions are expected to accept a dict of draws with
43+
shape ``(batch_size, ...)`` as the first (typically only)
44+
positional argument and return an NumPy array of shape
45+
``(batch_size,)``.
46+
The functions do not have to deal with an additional
47+
sample dimension, as appropriate reshaping is done internally.
3548
resolution : int, optional, default: 20
3649
The number of credibility intervals (CIs) to consider
3750
aggregation : callable or None, optional, default: np.median
@@ -55,6 +68,19 @@ def calibration_error(
5568
The (inferred) variable names.
5669
"""
5770

71+
if test_quantities is not None:
72+
updated_data = compute_test_quantities(
73+
targets=targets,
74+
estimates=estimates,
75+
variable_keys=variable_keys,
76+
variable_names=variable_names,
77+
test_quantities=test_quantities,
78+
)
79+
variable_names = updated_data["variable_names"]
80+
variable_keys = updated_data["variable_keys"]
81+
estimates = updated_data["estimates"]
82+
targets = updated_data["targets"]
83+
5884
samples = dicts_to_arrays(
5985
estimates=estimates,
6086
targets=targets,

bayesflow/diagnostics/metrics/calibration_log_gamma.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
44
from scipy.stats import binom
55

6-
from ...utils.dict_utils import dicts_to_arrays
6+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
77

88

99
def calibration_log_gamma(
1010
estimates: Mapping[str, np.ndarray] | np.ndarray,
1111
targets: Mapping[str, np.ndarray] | np.ndarray,
1212
variable_keys: Sequence[str] = None,
1313
variable_names: Sequence[str] = None,
14+
test_quantities: dict[str, Callable] = None,
1415
num_null_draws: int = 1000,
1516
quantile: float = 0.05,
1617
):
@@ -41,6 +42,18 @@ def calibration_log_gamma(
4142
By default, select all keys.
4243
variable_names : Sequence[str], optional (default = None)
4344
Optional variable names to show in the output.
45+
test_quantities : dict or None, optional, default: None
46+
A dict that maps plot titles to functions that compute
47+
test quantities based on estimate/target draws.
48+
49+
The dict keys are automatically added to ``variable_keys``
50+
and ``variable_names``.
51+
Test quantity functions are expected to accept a dict of draws with
52+
shape ``(batch_size, ...)`` as the first (typically only)
53+
positional argument and return an NumPy array of shape
54+
``(batch_size,)``.
55+
The functions do not have to deal with an additional
56+
sample dimension, as appropriate reshaping is done internally.
4457
quantile : float in (0, 1), optional, default 0.05
4558
The quantile from the null distribution to be used as a threshold.
4659
A lower quantile increases sensitivity to deviations from uniformity.
@@ -57,6 +70,21 @@ def calibration_log_gamma(
5770
- "variable_names" : str
5871
The (inferred) variable names.
5972
"""
73+
74+
# Optionally, compute and prepend test quantities from draws
75+
if test_quantities is not None:
76+
updated_data = compute_test_quantities(
77+
targets=targets,
78+
estimates=estimates,
79+
variable_keys=variable_keys,
80+
variable_names=variable_names,
81+
test_quantities=test_quantities,
82+
)
83+
variable_names = updated_data["variable_names"]
84+
variable_keys = updated_data["variable_keys"]
85+
estimates = updated_data["estimates"]
86+
targets = updated_data["targets"]
87+
6088
samples = dicts_to_arrays(
6189
estimates=estimates,
6290
targets=targets,

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def posterior_contraction(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
aggregation: Callable | None = np.median,
1415
) -> dict[str, any]:
1516
"""
@@ -27,6 +28,18 @@ def posterior_contraction(
2728
By default, select all keys.
2829
variable_names : Sequence[str], optional (default = None)
2930
Optional variable names to show in the output.
31+
test_quantities : dict or None, optional, default: None
32+
A dict that maps plot titles to functions that compute
33+
test quantities based on estimate/target draws.
34+
35+
The dict keys are automatically added to ``variable_keys``
36+
and ``variable_names``.
37+
Test quantity functions are expected to accept a dict of draws with
38+
shape ``(batch_size, ...)`` as the first (typically only)
39+
positional argument and return an NumPy array of shape
40+
``(batch_size,)``.
41+
The functions do not have to deal with an additional
42+
sample dimension, as appropriate reshaping is done internally.
3043
aggregation : callable or None, optional (default = np.median)
3144
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
3245
If None is provided, the individual values are returned.
@@ -50,6 +63,20 @@ def posterior_contraction(
5063
indicate low contraction.
5164
"""
5265

66+
# Optionally, compute and prepend test quantities from draws
67+
if test_quantities is not None:
68+
updated_data = compute_test_quantities(
69+
targets=targets,
70+
estimates=estimates,
71+
variable_keys=variable_keys,
72+
variable_names=variable_names,
73+
test_quantities=test_quantities,
74+
)
75+
variable_names = updated_data["variable_names"]
76+
variable_keys = updated_data["variable_keys"]
77+
estimates = updated_data["estimates"]
78+
targets = updated_data["targets"]
79+
5380
samples = dicts_to_arrays(
5481
estimates=estimates,
5582
targets=targets,

bayesflow/diagnostics/metrics/root_mean_squared_error.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def root_mean_squared_error(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
normalize: str | None = "range",
1415
aggregation: Callable = np.median,
1516
) -> dict[str, any]:
@@ -28,6 +29,18 @@ def root_mean_squared_error(
2829
By default, select all keys.
2930
variable_names : Sequence[str], optional (default = None)
3031
Optional variable names to show in the output.
32+
test_quantities : dict or None, optional, default: None
33+
A dict that maps plot titles to functions that compute
34+
test quantities based on estimate/target draws.
35+
36+
The dict keys are automatically added to ``variable_keys``
37+
and ``variable_names``.
38+
Test quantity functions are expected to accept a dict of draws with
39+
shape ``(batch_size, ...)`` as the first (typically only)
40+
positional argument and return an NumPy array of shape
41+
``(batch_size,)``.
42+
The functions do not have to deal with an additional
43+
sample dimension, as appropriate reshaping is done internally.
3144
normalize : str or None, optional (default = "range")
3245
Whether to normalize the RMSE using statistics of the prior samples.
3346
Possible options are ("mean", "range", "median", "iqr", "std", None)
@@ -52,6 +65,20 @@ def root_mean_squared_error(
5265
The (inferred) variable names.
5366
"""
5467

68+
# Optionally, compute and prepend test quantities from draws
69+
if test_quantities is not None:
70+
updated_data = compute_test_quantities(
71+
targets=targets,
72+
estimates=estimates,
73+
variable_keys=variable_keys,
74+
variable_names=variable_names,
75+
test_quantities=test_quantities,
76+
)
77+
variable_names = updated_data["variable_names"]
78+
variable_keys = updated_data["variable_keys"]
79+
estimates = updated_data["estimates"]
80+
targets = updated_data["targets"]
81+
5582
samples = dicts_to_arrays(
5683
estimates=estimates,
5784
targets=targets,

bayesflow/diagnostics/plots/calibration_histogram.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import matplotlib.pyplot as plt
44
import numpy as np
@@ -8,13 +8,15 @@
88

99
from bayesflow.utils import logging
1010
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
11+
from bayesflow.utils.dict_utils import compute_test_quantities
1112

1213

1314
def calibration_histogram(
1415
estimates: Mapping[str, np.ndarray] | np.ndarray,
1516
targets: Mapping[str, np.ndarray] | np.ndarray,
1617
variable_keys: Sequence[str] = None,
1718
variable_names: Sequence[str] = None,
19+
test_quantities: dict[str, Callable] = None,
1820
figsize: Sequence[float] = None,
1921
num_bins: int = 10,
2022
binomial_interval: float = 0.99,
@@ -46,6 +48,18 @@ def calibration_histogram(
4648
By default, select all keys.
4749
variable_names : list or None, optional, default: None
4850
The parameter names for nice plot titles. Inferred if None
51+
test_quantities : dict or None, optional, default: None
52+
A dict that maps plot titles to functions that compute
53+
test quantities based on estimate/target draws.
54+
55+
The dict keys are automatically added to ``variable_keys``
56+
and ``variable_names``.
57+
Test quantity functions are expected to accept a dict of draws with
58+
shape ``(batch_size, ...)`` as the first (typically only)
59+
positional argument and return an NumPy array of shape
60+
``(batch_size,)``.
61+
The functions do not have to deal with an additional
62+
sample dimension, as appropriate reshaping is done internally.
4963
figsize : tuple or None, optional, default : None
5064
The figure size passed to the matplotlib constructor. Inferred if None
5165
num_bins : int, optional, default: 10
@@ -75,6 +89,20 @@ def calibration_histogram(
7589
If there is a deviation form the expected shapes of `estimates` and `targets`.
7690
"""
7791

92+
# Optionally, compute and prepend test quantities from draws
93+
if test_quantities is not None:
94+
updated_data = compute_test_quantities(
95+
targets=targets,
96+
estimates=estimates,
97+
variable_keys=variable_keys,
98+
variable_names=variable_names,
99+
test_quantities=test_quantities,
100+
)
101+
variable_names = updated_data["variable_names"]
102+
variable_keys = updated_data["variable_keys"]
103+
estimates = updated_data["estimates"]
104+
targets = updated_data["targets"]
105+
78106
plot_data = prepare_plot_data(
79107
estimates=estimates,
80108
targets=targets,

bayesflow/diagnostics/plots/coverage.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Callable, Sequence, Mapping
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55

66
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_empirical_coverage
7+
from bayesflow.utils.dict_utils import compute_test_quantities
78

89

910
def coverage(
@@ -12,6 +13,7 @@ def coverage(
1213
difference: bool = False,
1314
variable_keys: Sequence[str] = None,
1415
variable_names: Sequence[str] = None,
16+
test_quantities: dict[str, Callable] = None,
1517
figsize: Sequence[int] = None,
1618
label_fontsize: int = 16,
1719
legend_fontsize: int = 14,
@@ -50,6 +52,18 @@ def coverage(
5052
By default, select all keys.
5153
variable_names : list or None, optional, default: None
5254
The parameter names for nice plot titles. Inferred if None
55+
test_quantities : dict or None, optional, default: None
56+
A dict that maps plot titles to functions that compute
57+
test quantities based on estimate/target draws.
58+
59+
The dict keys are automatically added to ``variable_keys``
60+
and ``variable_names``.
61+
Test quantity functions are expected to accept a dict of draws with
62+
shape ``(batch_size, ...)`` as the first (typically only)
63+
positional argument and return an NumPy array of shape
64+
``(batch_size,)``.
65+
The functions do not have to deal with an additional
66+
sample dimension, as appropriate reshaping is done internally.
5367
figsize : tuple or None, optional, default: None
5468
The figure size passed to the matplotlib constructor. Inferred if None.
5569
label_fontsize : int, optional, default: 16
@@ -80,6 +94,20 @@ def coverage(
8094
8195
"""
8296

97+
# Optionally, compute and prepend test quantities from draws
98+
if test_quantities is not None:
99+
updated_data = compute_test_quantities(
100+
targets=targets,
101+
estimates=estimates,
102+
variable_keys=variable_keys,
103+
variable_names=variable_names,
104+
test_quantities=test_quantities,
105+
)
106+
variable_names = updated_data["variable_names"]
107+
variable_keys = updated_data["variable_keys"]
108+
estimates = updated_data["estimates"]
109+
targets = updated_data["targets"]
110+
83111
# Gather plot data and metadata into a dictionary
84112
plot_data = prepare_plot_data(
85113
estimates=estimates,

0 commit comments

Comments
 (0)