Skip to content

Commit daa32a1

Browse files
committed
Merge branch 'fix-stable-consistency' into compositional_sampling_diffusion
2 parents 983cb8d + d13da56 commit daa32a1

17 files changed

+717
-171
lines changed

bayesflow/diagnostics/metrics/calibration_error.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def calibration_error(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
resolution: int = 20,
1415
aggregation: Callable = np.median,
1516
min_quantile: float = 0.005,
@@ -32,6 +33,18 @@ def calibration_error(
3233
By default, select all keys.
3334
variable_names : Sequence[str], optional (default = None)
3435
Optional variable names to show in the output.
36+
test_quantities : dict or None, optional, default: None
37+
A dict that maps plot titles to functions that compute
38+
test quantities based on estimate/target draws.
39+
40+
The dict keys are automatically added to ``variable_keys``
41+
and ``variable_names``.
42+
Test quantity functions are expected to accept a dict of draws with
43+
shape ``(batch_size, ...)`` as the first (typically only)
44+
positional argument and return a NumPy array of shape
45+
``(batch_size,)``.
46+
The functions do not have to deal with an additional
47+
sample dimension, as appropriate reshaping is done internally.
3548
resolution : int, optional, default: 20
3649
The number of credibility intervals (CIs) to consider
3750
aggregation : callable or None, optional, default: np.median
@@ -55,6 +68,19 @@ def calibration_error(
5568
The (inferred) variable names.
5669
"""
5770

71+
if test_quantities is not None:
72+
updated_data = compute_test_quantities(
73+
targets=targets,
74+
estimates=estimates,
75+
variable_keys=variable_keys,
76+
variable_names=variable_names,
77+
test_quantities=test_quantities,
78+
)
79+
variable_names = updated_data["variable_names"]
80+
variable_keys = updated_data["variable_keys"]
81+
estimates = updated_data["estimates"]
82+
targets = updated_data["targets"]
83+
5884
samples = dicts_to_arrays(
5985
estimates=estimates,
6086
targets=targets,

bayesflow/diagnostics/metrics/calibration_log_gamma.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
44
from scipy.stats import binom
55

6-
from ...utils.dict_utils import dicts_to_arrays
6+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
77

88

99
def calibration_log_gamma(
1010
estimates: Mapping[str, np.ndarray] | np.ndarray,
1111
targets: Mapping[str, np.ndarray] | np.ndarray,
1212
variable_keys: Sequence[str] = None,
1313
variable_names: Sequence[str] = None,
14+
test_quantities: dict[str, Callable] = None,
1415
num_null_draws: int = 1000,
1516
quantile: float = 0.05,
1617
):
@@ -41,6 +42,18 @@ def calibration_log_gamma(
4142
By default, select all keys.
4243
variable_names : Sequence[str], optional (default = None)
4344
Optional variable names to show in the output.
45+
test_quantities : dict or None, optional, default: None
46+
A dict that maps plot titles to functions that compute
47+
test quantities based on estimate/target draws.
48+
49+
The dict keys are automatically added to ``variable_keys``
50+
and ``variable_names``.
51+
Test quantity functions are expected to accept a dict of draws with
52+
shape ``(batch_size, ...)`` as the first (typically only)
53+
positional argument and return a NumPy array of shape
54+
``(batch_size,)``.
55+
The functions do not have to deal with an additional
56+
sample dimension, as appropriate reshaping is done internally.
4457
quantile : float in (0, 1), optional, default 0.05
4558
The quantile from the null distribution to be used as a threshold.
4659
A lower quantile increases sensitivity to deviations from uniformity.
@@ -57,6 +70,21 @@ def calibration_log_gamma(
5770
- "variable_names" : str
5871
The (inferred) variable names.
5972
"""
73+
74+
# Optionally, compute and prepend test quantities from draws
75+
if test_quantities is not None:
76+
updated_data = compute_test_quantities(
77+
targets=targets,
78+
estimates=estimates,
79+
variable_keys=variable_keys,
80+
variable_names=variable_names,
81+
test_quantities=test_quantities,
82+
)
83+
variable_names = updated_data["variable_names"]
84+
variable_keys = updated_data["variable_keys"]
85+
estimates = updated_data["estimates"]
86+
targets = updated_data["targets"]
87+
6088
samples = dicts_to_arrays(
6189
estimates=estimates,
6290
targets=targets,

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def posterior_contraction(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
aggregation: Callable | None = np.median,
1415
) -> dict[str, any]:
1516
"""
@@ -27,6 +28,18 @@ def posterior_contraction(
2728
By default, select all keys.
2829
variable_names : Sequence[str], optional (default = None)
2930
Optional variable names to show in the output.
31+
test_quantities : dict or None, optional, default: None
32+
A dict that maps plot titles to functions that compute
33+
test quantities based on estimate/target draws.
34+
35+
The dict keys are automatically added to ``variable_keys``
36+
and ``variable_names``.
37+
Test quantity functions are expected to accept a dict of draws with
38+
shape ``(batch_size, ...)`` as the first (typically only)
39+
positional argument and return a NumPy array of shape
40+
``(batch_size,)``.
41+
The functions do not have to deal with an additional
42+
sample dimension, as appropriate reshaping is done internally.
3043
aggregation : callable or None, optional (default = np.median)
3144
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
3245
If None is provided, the individual values are returned.
@@ -50,6 +63,20 @@ def posterior_contraction(
5063
indicate low contraction.
5164
"""
5265

66+
# Optionally, compute and prepend test quantities from draws
67+
if test_quantities is not None:
68+
updated_data = compute_test_quantities(
69+
targets=targets,
70+
estimates=estimates,
71+
variable_keys=variable_keys,
72+
variable_names=variable_names,
73+
test_quantities=test_quantities,
74+
)
75+
variable_names = updated_data["variable_names"]
76+
variable_keys = updated_data["variable_keys"]
77+
estimates = updated_data["estimates"]
78+
targets = updated_data["targets"]
79+
5380
samples = dicts_to_arrays(
5481
estimates=estimates,
5582
targets=targets,

bayesflow/diagnostics/metrics/root_mean_squared_error.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import numpy as np
44

5-
from ...utils.dict_utils import dicts_to_arrays
5+
from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
66

77

88
def root_mean_squared_error(
99
estimates: Mapping[str, np.ndarray] | np.ndarray,
1010
targets: Mapping[str, np.ndarray] | np.ndarray,
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
13+
test_quantities: dict[str, Callable] = None,
1314
normalize: str | None = "range",
1415
aggregation: Callable = np.median,
1516
) -> dict[str, any]:
@@ -28,6 +29,18 @@ def root_mean_squared_error(
2829
By default, select all keys.
2930
variable_names : Sequence[str], optional (default = None)
3031
Optional variable names to show in the output.
32+
test_quantities : dict or None, optional, default: None
33+
A dict that maps plot titles to functions that compute
34+
test quantities based on estimate/target draws.
35+
36+
The dict keys are automatically added to ``variable_keys``
37+
and ``variable_names``.
38+
Test quantity functions are expected to accept a dict of draws with
39+
shape ``(batch_size, ...)`` as the first (typically only)
40+
positional argument and return a NumPy array of shape
41+
``(batch_size,)``.
42+
The functions do not have to deal with an additional
43+
sample dimension, as appropriate reshaping is done internally.
3144
normalize : str or None, optional (default = "range")
3245
Whether to normalize the RMSE using statistics of the prior samples.
3346
Possible options are ("mean", "range", "median", "iqr", "std", None)
@@ -52,6 +65,20 @@ def root_mean_squared_error(
5265
The (inferred) variable names.
5366
"""
5467

68+
# Optionally, compute and prepend test quantities from draws
69+
if test_quantities is not None:
70+
updated_data = compute_test_quantities(
71+
targets=targets,
72+
estimates=estimates,
73+
variable_keys=variable_keys,
74+
variable_names=variable_names,
75+
test_quantities=test_quantities,
76+
)
77+
variable_names = updated_data["variable_names"]
78+
variable_keys = updated_data["variable_keys"]
79+
estimates = updated_data["estimates"]
80+
targets = updated_data["targets"]
81+
5582
samples = dicts_to_arrays(
5683
estimates=estimates,
5784
targets=targets,

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ def calibration_ecdf(
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
1717
test_quantities: dict[str, Callable] = None,
18-
difference: bool = False,
18+
difference: bool = True,
1919
stacked: bool = False,
2020
rank_type: str | np.ndarray = "fractional",
2121
figsize: Sequence[float] = None,
2222
label_fontsize: int = 16,
2323
legend_fontsize: int = 14,
24-
legend_location: str = "upper right",
24+
legend_location: str = "lower right",
2525
title_fontsize: int = 18,
2626
tick_fontsize: int = 12,
2727
rank_ecdf_color: str = "#132a70",
@@ -59,7 +59,7 @@ def calibration_ecdf(
5959
The posterior draws obtained from n_data_sets
6060
targets : np.ndarray of shape (n_data_sets, n_params)
6161
The prior draws obtained for generating n_data_sets
62-
difference : bool, optional, default: False
62+
difference : bool, optional, default: True
6363
If `True`, plots the ECDF difference.
6464
Enables a more dynamic visualization range.
6565
stacked : bool, optional, default: False
@@ -98,7 +98,9 @@ def calibration_ecdf(
9898
label_fontsize : int, optional, default: 16
9999
The font size of the x-label and y-label texts
100100
legend_fontsize : int, optional, default: 14
101-
The font size of the legend text
101+
The font size of the legend text.
102+
legend_location : str, optional, default: 'lower right'
103+
The location of the legend.
102104
title_fontsize : int, optional, default: 18
103105
The font size of the title text.
104106
Only relevant if `stacked=False`
@@ -211,11 +213,13 @@ def calibration_ecdf(
211213
else:
212214
titles = ["Stacked ECDFs"]
213215

214-
for ax, title in zip(plot_data["axes"].flat, titles):
216+
for i, (ax, title) in enumerate(zip(plot_data["axes"].flat, titles)):
215217
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
216-
ax.legend(fontsize=legend_fontsize, loc=legend_location)
217218
ax.set_title(title, fontsize=title_fontsize)
218219

220+
if i == 0:
221+
ax.legend(fontsize=legend_fontsize, loc=legend_location)
222+
219223
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
220224

221225
add_titles_and_labels(

bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ def calibration_ecdf_from_quantiles(
1414
quantiles_key: str = "quantiles",
1515
variable_keys: Sequence[str] = None,
1616
variable_names: Sequence[str] = None,
17-
difference: bool = False,
17+
difference: bool = True,
1818
stacked: bool = False,
1919
figsize: Sequence[float] = None,
2020
label_fontsize: int = 16,
2121
legend_fontsize: int = 14,
22-
legend_location: str = "upper right",
22+
legend_location: str = "lower right",
2323
title_fontsize: int = 18,
2424
tick_fontsize: int = 12,
2525
rank_ecdf_color: str = "#132a70",
@@ -69,7 +69,7 @@ def calibration_ecdf_from_quantiles(
6969
variable_names : list or None, optional, default: None
7070
The parameter names for nice plot titles.
7171
Inferred if None. Only relevant if `stacked=False`.
72-
difference : bool, optional, default: False
72+
difference : bool, optional, default: True
7373
If `True`, plots the ECDF difference.
7474
Enables a more dynamic visualization range.
7575
stacked : bool, optional, default: False
@@ -82,7 +82,9 @@ def calibration_ecdf_from_quantiles(
8282
label_fontsize : int, optional, default: 16
8383
The font size of the x-label and y-label texts
8484
legend_fontsize : int, optional, default: 14
85-
The font size of the legend text
85+
The font size of the legend text.
86+
legend_location : str, optional, default: 'lower right'
87+
The location of the legend.
8688
title_fontsize : int, optional, default: 18
8789
The font size of the title text.
8890
Only relevant if `stacked=False`

bayesflow/diagnostics/plots/calibration_histogram.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import matplotlib.pyplot as plt
44
import numpy as np
@@ -8,13 +8,15 @@
88

99
from bayesflow.utils import logging
1010
from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
11+
from bayesflow.utils.dict_utils import compute_test_quantities
1112

1213

1314
def calibration_histogram(
1415
estimates: Mapping[str, np.ndarray] | np.ndarray,
1516
targets: Mapping[str, np.ndarray] | np.ndarray,
1617
variable_keys: Sequence[str] = None,
1718
variable_names: Sequence[str] = None,
19+
test_quantities: dict[str, Callable] = None,
1820
figsize: Sequence[float] = None,
1921
num_bins: int = 10,
2022
binomial_interval: float = 0.99,
@@ -46,6 +48,18 @@ def calibration_histogram(
4648
By default, select all keys.
4749
variable_names : list or None, optional, default: None
4850
The parameter names for nice plot titles. Inferred if None
51+
test_quantities : dict or None, optional, default: None
52+
A dict that maps plot titles to functions that compute
53+
test quantities based on estimate/target draws.
54+
55+
The dict keys are automatically added to ``variable_keys``
56+
and ``variable_names``.
57+
Test quantity functions are expected to accept a dict of draws with
58+
shape ``(batch_size, ...)`` as the first (typically only)
59+
positional argument and return a NumPy array of shape
60+
``(batch_size,)``.
61+
The functions do not have to deal with an additional
62+
sample dimension, as appropriate reshaping is done internally.
4963
figsize : tuple or None, optional, default : None
5064
The figure size passed to the matplotlib constructor. Inferred if None
5165
num_bins : int, optional, default: 10
@@ -75,6 +89,20 @@ def calibration_histogram(
7589
If there is a deviation from the expected shapes of `estimates` and `targets`.
7690
"""
7791

92+
# Optionally, compute and prepend test quantities from draws
93+
if test_quantities is not None:
94+
updated_data = compute_test_quantities(
95+
targets=targets,
96+
estimates=estimates,
97+
variable_keys=variable_keys,
98+
variable_names=variable_names,
99+
test_quantities=test_quantities,
100+
)
101+
variable_names = updated_data["variable_names"]
102+
variable_keys = updated_data["variable_keys"]
103+
estimates = updated_data["estimates"]
104+
targets = updated_data["targets"]
105+
78106
plot_data = prepare_plot_data(
79107
estimates=estimates,
80108
targets=targets,

0 commit comments

Comments
 (0)