
Commit 0e2d646

Add posterior z score
1 parent fde17e2 commit 0e2d646

File tree

4 files changed, +131 -0 lines changed


bayesflow/diagnostics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
     calibration_error,
     calibration_log_gamma,
     posterior_contraction,
+    posterior_z_score,
     summary_space_comparison,
 )

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -5,3 +5,4 @@
 from .classifier_two_sample_test import classifier_two_sample_test
 from .model_misspecification import bootstrap_comparison, summary_space_comparison
 from .calibration_log_gamma import calibration_log_gamma, gamma_null_distribution, gamma_discrepancy
+from .posterior_z_score import posterior_z_score

bayesflow/diagnostics/metrics/posterior_z_score.py

Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
from collections.abc import Sequence, Mapping, Callable

import numpy as np

from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities


def posterior_z_score(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    variable_keys: Sequence[str] = None,
    variable_names: Sequence[str] = None,
    test_quantities: dict[str, Callable] = None,
    aggregation: Callable | None = np.median,
) -> dict[str, any]:
    """
    Computes the posterior z-score for the given samples according to [1]:

    post_z_score = (posterior_mean - true_parameters) / posterior_std

    The score is adequate if it is centered around zero and spreads roughly
    within the interval [-3, 3].

    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021).
    Toward a principled Bayesian workflow in cognitive science.
    Psychological Methods, 26(1), 103.

    Paper also available at https://arxiv.org/abs/1904.12765

    Parameters
    ----------
    estimates : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
        Posterior samples, comprising `num_draws_post` random draws from the posterior distribution
        for each of the `num_datasets` data sets.
    targets : np.ndarray of shape (num_datasets, num_variables)
        Prior samples, comprising `num_datasets` ground truths.
    variable_keys : Sequence[str], optional (default = None)
        Select keys from the dictionaries provided in estimates and targets.
        By default, select all keys.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to show in the output.
    test_quantities : dict or None, optional, default: None
        A dict that maps plot titles to functions that compute
        test quantities based on estimate/target draws.

        The dict keys are automatically added to ``variable_keys``
        and ``variable_names``.
        Test quantity functions are expected to accept a dict of draws with
        shape ``(batch_size, ...)`` as the first (typically only)
        positional argument and return a NumPy array of shape
        ``(batch_size,)``.
        The functions do not have to deal with an additional
        sample dimension, as appropriate reshaping is done internally.
    aggregation : callable or None, optional (default = np.median)
        Function to aggregate the z-scores across datasets. Typically `np.mean` or `np.median`.
        If None is provided, the individual values are returned.

    Returns
    -------
    result : dict
        Dictionary containing:

        - "values" : float or np.ndarray
            The (optionally aggregated) posterior z-score per variable.
        - "metric_name" : str
            The name of the metric ("Posterior z-score").
        - "variable_names" : str
            The (inferred) variable names.

    Notes
    -----
    The posterior z-score quantifies how far the posterior mean lies from
    the true generating parameter, in units of the posterior standard deviation.
    Values near 0 (in absolute terms) mean the posterior mean is close to the
    truth; large absolute values indicate substantial deviation.
    The sign shows the direction of the bias.
    """

    # Optionally, compute and prepend test quantities from draws
    if test_quantities is not None:
        updated_data = compute_test_quantities(
            targets=targets,
            estimates=estimates,
            variable_keys=variable_keys,
            variable_names=variable_names,
            test_quantities=test_quantities,
        )
        variable_names = updated_data["variable_names"]
        variable_keys = updated_data["variable_keys"]
        estimates = updated_data["estimates"]
        targets = updated_data["targets"]

    samples = dicts_to_arrays(
        estimates=estimates,
        targets=targets,
        variable_keys=variable_keys,
        variable_names=variable_names,
    )

    # Per-dataset posterior means and standard deviations across posterior draws
    post_means = samples["estimates"].mean(axis=1)
    post_stds = samples["estimates"].std(axis=1, ddof=1)

    # Signed distance of the posterior mean from the ground truth, in posterior SD units
    z_score = (post_means - samples["targets"]) / post_stds
    if aggregation is not None:
        z_score = aggregation(z_score, axis=0)
    variable_names = samples["estimates"].variable_names
    return {"values": z_score, "metric_name": "Posterior z-score", "variable_names": variable_names}
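
For orientation, here is a minimal usage sketch of the new metric (not part of the commit). It assumes a BayesFlow installation that includes this change, imported as `bayesflow`, and uses synthetic NumPy draws in the dict layout accepted by `dicts_to_arrays`; the keys, shapes, and seed are illustrative only.

import numpy as np
import bayesflow as bf

rng = np.random.default_rng(2024)

# 100 simulated datasets, 500 posterior draws each, two parameter blocks
estimates = {
    "beta": rng.normal(size=(100, 500, 2)),
    "sigma": rng.normal(size=(100, 500, 1)),
}
# one ground-truth draw per dataset, matching the trailing shapes of the estimates
targets = {
    "beta": rng.normal(size=(100, 2)),
    "sigma": rng.normal(size=(100, 1)),
}

# Aggregated (median) z-score per variable; adequate if roughly within [-3, 3]
out = bf.diagnostics.metrics.posterior_z_score(estimates, targets)
print(out["variable_names"], out["values"])

# Per-dataset z-scores of shape (100, 3) when aggregation is disabled
raw = bf.diagnostics.metrics.posterior_z_score(estimates, targets, aggregation=None)

# A derived test quantity computed from the draws appears as an extra variable
tq = {"beta_sum": lambda data: np.sum(data["beta"], axis=-1)}
with_tq = bf.diagnostics.metrics.posterior_z_score(estimates, targets, test_quantities=tq)

The unaggregated values can help spot individual simulations for which the posterior is badly off-target.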

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 20 additions & 0 deletions

@@ -66,6 +66,26 @@ def test_posterior_contraction(random_estimates, random_targets):
     assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
 
 
+def test_posterior_z_score(random_estimates, random_targets):
+    # basic functionality: automatic variable names
+    out = bf.diagnostics.metrics.posterior_z_score(random_estimates, random_targets)
+    assert list(out.keys()) == ["values", "metric_name", "variable_names"]
+    assert out["values"].shape == (num_variables(random_estimates),)
+    assert out["metric_name"] == "Posterior z-score"
+    assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+    # test without aggregation
+    out = bf.diagnostics.metrics.posterior_z_score(random_estimates, random_targets, aggregation=None)
+    assert out["values"].shape == (random_estimates["sigma"].shape[0], num_variables(random_estimates))
+
+    # test quantities
+    test_quantities = {
+        r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
+        r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
+    }
+    out = bf.diagnostics.metrics.posterior_z_score(random_estimates, random_targets, test_quantities=test_quantities)
+    assert out["values"].shape[0] == len(test_quantities) + num_variables(random_estimates)
+
+
 def test_root_mean_squared_error(random_estimates, random_targets):
     # basic functionality: automatic variable names
     out = bf.diagnostics.metrics.root_mean_squared_error(random_estimates, random_targets)
