Skip to content

Commit b050ffb

Browse files
add log_gamma diagnostic
1 parent 361fa45 commit b050ffb

File tree

2 files changed

+195
-1
lines changed

2 files changed

+195
-1
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from collections.abc import Mapping, Sequence
2+
3+
import numpy as np
4+
from scipy.stats import binom
5+
6+
from ...utils.dict_utils import dicts_to_arrays
7+
8+
9+
def log_gamma(
10+
estimates: Mapping[str, np.ndarray] | np.ndarray,
11+
targets: Mapping[str, np.ndarray] | np.ndarray,
12+
variable_keys: Sequence[str] = None,
13+
variable_names: Sequence[str] = None,
14+
num_null_draws: int = 1000,
15+
quantile: float = 0.05,
16+
):
17+
"""
18+
Compute the log gamma discrepancy statistic, see [1] for additional information.
19+
Log gamma is log(gamma/gamma_null), where gamma_null is the 5th percentile of the
20+
null distribution under uniformity of ranks.
21+
That is, if adopting a hypothesis testing framework,then log_gamma < 0 implies
22+
a rejection of the hypothesis of uniform ranks at the 5\% level.
23+
This diagnostic is typically more sensitive than the Kolmogorov-Smirnoff test or
24+
ChiSq test.
25+
26+
[1] Martin Modrák. Angie H. Moon. Shinyoung Kim. Paul Bürkner. Niko Huurre.
27+
Kateřina Faltejsková. Andrew Gelman. Aki Vehtari.
28+
"Simulation-Based Calibration Checking for Bayesian Computation:
29+
The Choice of Test Quantities Shapes Sensitivity."
30+
Bayesian Anal. 20 (2) 461 - 488, June 2025. https://doi.org/10.1214/23-BA1404
31+
32+
Parameters
33+
----------
34+
estimates : np.ndarray of shape (num_datasets, num_draws, num_variables)
35+
The random draws from the approximate posteriors over ``num_datasets``
36+
targets : np.ndarray of shape (num_datasets, num_variables)
37+
The corresponding ground-truth values sampled from the prior
38+
variable_keys : Sequence[str], optional (default = None)
39+
Select keys from the dictionaries provided in estimates and targets.
40+
By default, select all keys.
41+
variable_names : Sequence[str], optional (default = None)
42+
Optional variable names to show in the output.
43+
quantile : float in (0, 1), optional, default 0.05
44+
The quantile from the null distribution to be used as a threshold.
45+
A lower quantile increases sensitivity to deviations from uniformity.
46+
"""
47+
samples = dicts_to_arrays(
48+
estimates=estimates,
49+
targets=targets,
50+
variable_keys=variable_keys,
51+
variable_names=variable_names,
52+
)
53+
54+
num_ranks = samples["estimates"].shape[0]
55+
num_post_draws = samples["estimates"].shape[1]
56+
57+
# rank statistics
58+
ranks = np.sum(samples["estimates"] < samples["targets"][:, None], axis=1)
59+
60+
# null distribution and threshold
61+
null_distribution = gamma_null_distribution(num_ranks, num_post_draws, num_null_draws)
62+
null_quantile = np.quantile(null_distribution, quantile)
63+
64+
# compute log gamma for each parameter
65+
log_gammas = []
66+
for i in range(ranks.shape[-1]):
67+
gamma = gamma_discrepancy(ranks[:, i], num_post_draws=num_post_draws)
68+
log_gammas.append(np.log(gamma / null_quantile))
69+
70+
output = {
71+
"values": np.array(log_gammas),
72+
"metric_name": "Log Gamma",
73+
"variable_names": samples["estimates"].variable_names,
74+
}
75+
76+
return output
77+
78+
79+
def gamma_null_distribution(num_ranks: int, num_post_draws: int = 1000, num_null_draws: int = 1000) -> np.ndarray:
80+
"""
81+
Computes the distribution of expected gamma values under uniformity of ranks.
82+
83+
Parameters
84+
----------
85+
num_ranks : int
86+
Number of ranks to use for each gamma.
87+
num_post_draws : int, optional, default 1000
88+
Number of posterior draws that were used to calculate the rank distribution.
89+
num_null_draws : int, optional, default 1000
90+
Number of returned gamma values under uniformity of ranks.
91+
92+
Returns
93+
-------
94+
result : np.ndarray
95+
Array of shape (num_null_draws,) containing gamma values under uniformity of ranks.
96+
"""
97+
z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)
98+
gamma = np.empty(num_null_draws)
99+
100+
# loop non-vectorized to reduce memory footprint
101+
for i in range(num_null_draws):
102+
u = np.random.uniform(size=num_ranks)
103+
F_z = np.mean(u[:, None] < z_i, axis=0)
104+
bin_1 = binom.cdf(num_ranks * F_z, num_ranks, z_i)
105+
bin_2 = 1 - binom.cdf(num_ranks * F_z - 1, num_ranks, z_i)
106+
107+
gamma[i] = 2 * np.min(np.minimum(bin_1, bin_2))
108+
109+
return gamma
110+
111+
112+
def gamma_discrepancy(ranks: np.ndarray, num_post_draws: int = 100) -> float:
113+
"""
114+
Quantifies deviation from uniformity by the likelihood of observing the
115+
most extreme point on the empirical CDF of the given rank distribution
116+
according to [1] (equation 7).
117+
118+
[1] Martin Modrák. Angie H. Moon. Shinyoung Kim. Paul Bürkner. Niko Huurre.
119+
Kateřina Faltejsková. Andrew Gelman. Aki Vehtari.
120+
"Simulation-Based Calibration Checking for Bayesian Computation:
121+
The Choice of Test Quantities Shapes Sensitivity."
122+
Bayesian Anal. 20 (2) 461 - 488, June 2025. https://doi.org/10.1214/23-BA1404
123+
124+
Parameters
125+
----------
126+
ranks : array of shape (num_ranks,)
127+
Empirical rank distribution
128+
num_post_draws : int, optional, default 100
129+
Number of posterior draws used to generate ranks.
130+
131+
Returns
132+
-------
133+
result : float
134+
Gamma discrepancy values for each parameter.
135+
"""
136+
num_ranks = len(ranks)
137+
138+
# observed count of ranks smaller than i
139+
R_i = np.array([sum(ranks < i) for i in range(1, num_post_draws + 2)])
140+
141+
# expected proportion of ranks smaller than i
142+
z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)
143+
144+
bin_1 = binom.cdf(R_i, num_ranks, z_i)
145+
bin_2 = 1 - binom.cdf(R_i - 1, num_ranks, z_i)
146+
147+
# likelihood of obtaining the most extreme point on the empirical CDF
148+
# if the rank distribution was indeed uniform
149+
return float(2 * np.min(np.minimum(bin_1, bin_2)))

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import numpy as np
21
import keras
2+
import numpy as np
33
import pytest
4+
from scipy.stats import binom
45

56
import bayesflow as bf
67

@@ -84,6 +85,50 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
8485
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)
8586

8687

88+
def test_log_gamma():
89+
# This is a function test for simulation-based calibration.
90+
# First, we sample from a known generative process and then run SBC.
91+
# If the log gamma statistic is correctly implemented, a 95% interval should exclude
92+
# the true value 5% of the time.
93+
94+
N = 30 # number of samples
95+
S = 1000 # number of posterior draws
96+
D = 1000 # number of datasets
97+
98+
def run_sbc(N=N, S=S, D=D, bias=0):
99+
rng = np.random.default_rng()
100+
prior_draws = rng.beta(2, 2, size=D)
101+
successes = rng.binomial(N, prior_draws)
102+
103+
# Analytical posterior:
104+
# if theta ~ Beta(2, 2), then p(theta|successes) is Beta(2 + successes | 2 + N - successes).
105+
posterior_draws = rng.beta(2 + successes + bias, 2 + N - successes + bias, size=(S, D))
106+
107+
# these ranks are uniform if bias=0
108+
ranks = np.sum(posterior_draws < prior_draws, axis=0)
109+
110+
# this is the distribution of gamma under uniform ranks
111+
gamma_null = bf.diagnostics.metrics.log_gamma.gamma_null_distribution(D, S, num_null_draws=100)
112+
lower, upper = np.quantile(gamma_null, (0.05, 0.995))
113+
114+
# this is the empirical gamma
115+
observed_gamma = bf.diagnostics.metrics.log_gamma.gamma_discrepancy(ranks, num_post_draws=S)
116+
117+
in_interval = lower <= observed_gamma < upper
118+
119+
return in_interval
120+
121+
sbc_calibration = [run_sbc(N=N, S=S, D=D) for _ in range(100)]
122+
lower_expected, upper_expected = binom.ppf((0.005, 0.995), 100, 0.95)
123+
124+
# this test should fail with a probability of 1%
125+
assert lower_expected <= np.sum(sbc_calibration) <= upper_expected
126+
127+
# sbc should almost always fial for slightly biased posterior draws
128+
sbc_calibration = [run_sbc(N=N, S=S, D=D, bias=1) for _ in range(100)]
129+
assert not lower_expected <= np.sum(sbc_calibration) <= upper_expected
130+
131+
87132
def test_bootstrap_comparison_shapes():
88133
"""Test the bootstrap_comparison output shapes."""
89134
observed_samples = np.random.rand(10, 5)

0 commit comments

Comments
 (0)