
Commit 0b09e55

Merge remote-tracking branch 'upstream/dev' into fix-standardization-layer

2 parents 0b32998 + 0c99bd9
File tree: 7 files changed, +246 −26 lines

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -4,3 +4,4 @@
 from .expected_calibration_error import expected_calibration_error
 from .classifier_two_sample_test import classifier_two_sample_test
 from .model_misspecification import bootstrap_comparison, summary_space_comparison
+from .sbc import log_gamma

bayesflow/diagnostics/metrics/sbc.py

Lines changed: 162 additions & 0 deletions

@@ -0,0 +1,162 @@
+from collections.abc import Mapping, Sequence
+
+import numpy as np
+from scipy.stats import binom
+
+from ...utils.dict_utils import dicts_to_arrays
+
+
+def log_gamma(
+    estimates: Mapping[str, np.ndarray] | np.ndarray,
+    targets: Mapping[str, np.ndarray] | np.ndarray,
+    variable_keys: Sequence[str] = None,
+    variable_names: Sequence[str] = None,
+    num_null_draws: int = 1000,
+    quantile: float = 0.05,
+):
+    """
+    Compute the log gamma discrepancy statistic, see [1] for additional information.
+    Log gamma is log(gamma / gamma_null), where gamma_null is the 5th percentile of the
+    null distribution under uniformity of ranks.
+    That is, if adopting a hypothesis testing framework, then log_gamma < 0 implies
+    a rejection of the hypothesis of uniform ranks at the 5% level.
+    This diagnostic is typically more sensitive than the Kolmogorov-Smirnov test or
+    the chi-squared test.
+
+    [1] Martin Modrák, Angie H. Moon, Shinyoung Kim, Paul Bürkner, Niko Huurre,
+    Kateřina Faltejsková, Andrew Gelman, Aki Vehtari.
+    "Simulation-Based Calibration Checking for Bayesian Computation:
+    The Choice of Test Quantities Shapes Sensitivity."
+    Bayesian Anal. 20 (2) 461-488, June 2025. https://doi.org/10.1214/23-BA1404
+
+    Parameters
+    ----------
+    estimates : np.ndarray of shape (num_datasets, num_draws, num_variables)
+        The random draws from the approximate posteriors over ``num_datasets`` datasets
+    targets : np.ndarray of shape (num_datasets, num_variables)
+        The corresponding ground-truth values sampled from the prior
+    variable_keys : Sequence[str], optional (default = None)
+        Select keys from the dictionaries provided in estimates and targets.
+        By default, select all keys.
+    variable_names : Sequence[str], optional (default = None)
+        Optional variable names to show in the output.
+    num_null_draws : int, optional, default 1000
+        Number of draws from the null distribution used to estimate the threshold.
+    quantile : float in (0, 1), optional, default 0.05
+        The quantile from the null distribution to be used as a threshold.
+        A lower quantile increases sensitivity to deviations from uniformity.
+
+    Returns
+    -------
+    result : dict
+        Dictionary containing:
+
+        - "values" : float or np.ndarray
+            The log gamma values per variable
+        - "metric_name" : str
+            The name of the metric ("Log Gamma").
+        - "variable_names" : str
+            The (inferred) variable names.
+    """
+    samples = dicts_to_arrays(
+        estimates=estimates,
+        targets=targets,
+        variable_keys=variable_keys,
+        variable_names=variable_names,
+    )
+
+    num_ranks = samples["estimates"].shape[0]
+    num_post_draws = samples["estimates"].shape[1]
+
+    # rank statistics
+    ranks = np.sum(samples["estimates"] < samples["targets"][:, None], axis=1)
+
+    # null distribution and threshold
+    null_distribution = gamma_null_distribution(num_ranks, num_post_draws, num_null_draws)
+    null_quantile = np.quantile(null_distribution, quantile)
+
+    # compute log gamma for each parameter
+    log_gammas = np.empty(ranks.shape[-1])
+
+    for i in range(ranks.shape[-1]):
+        gamma = gamma_discrepancy(ranks[:, i], num_post_draws=num_post_draws)
+        log_gammas[i] = np.log(gamma / null_quantile)
+
+    output = {
+        "values": log_gammas,
+        "metric_name": "Log Gamma",
+        "variable_names": samples["estimates"].variable_names,
+    }
+
+    return output
+
+
+def gamma_null_distribution(num_ranks: int, num_post_draws: int = 1000, num_null_draws: int = 1000) -> np.ndarray:
+    """
+    Computes the distribution of expected gamma values under uniformity of ranks.
+
+    Parameters
+    ----------
+    num_ranks : int
+        Number of ranks to use for each gamma.
+    num_post_draws : int, optional, default 1000
+        Number of posterior draws that were used to calculate the rank distribution.
+    num_null_draws : int, optional, default 1000
+        Number of returned gamma values under uniformity of ranks.
+
+    Returns
+    -------
+    result : np.ndarray
+        Array of shape (num_null_draws,) containing gamma values under uniformity of ranks.
+    """
+    z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)
+    gamma = np.empty(num_null_draws)
+
+    # loop non-vectorized to reduce memory footprint
+    for i in range(num_null_draws):
+        u = np.random.uniform(size=num_ranks)
+        F_z = np.mean(u[:, None] < z_i, axis=0)
+        bin_1 = binom.cdf(num_ranks * F_z, num_ranks, z_i)
+        bin_2 = 1 - binom.cdf(num_ranks * F_z - 1, num_ranks, z_i)
+
+        gamma[i] = 2 * np.min(np.minimum(bin_1, bin_2))
+
+    return gamma
+
+
+def gamma_discrepancy(ranks: np.ndarray, num_post_draws: int = 100) -> float:
+    """
+    Quantifies deviation from uniformity by the likelihood of observing the
+    most extreme point on the empirical CDF of the given rank distribution
+    according to [1] (equation 7).
+
+    [1] Martin Modrák, Angie H. Moon, Shinyoung Kim, Paul Bürkner, Niko Huurre,
+    Kateřina Faltejsková, Andrew Gelman, Aki Vehtari.
+    "Simulation-Based Calibration Checking for Bayesian Computation:
+    The Choice of Test Quantities Shapes Sensitivity."
+    Bayesian Anal. 20 (2) 461-488, June 2025. https://doi.org/10.1214/23-BA1404
+
+    Parameters
+    ----------
+    ranks : array of shape (num_ranks,)
+        Empirical rank distribution
+    num_post_draws : int, optional, default 100
+        Number of posterior draws used to generate ranks.
+
+    Returns
+    -------
+    result : float
+        Gamma discrepancy value for the given rank distribution.
+    """
+    num_ranks = len(ranks)
+
+    # observed count of ranks smaller than i
+    R_i = np.array([sum(ranks < i) for i in range(1, num_post_draws + 2)])
+
+    # expected proportion of ranks smaller than i
+    z_i = np.arange(1, num_post_draws + 2) / (num_post_draws + 1)
+
+    bin_1 = binom.cdf(R_i, num_ranks, z_i)
+    bin_2 = 1 - binom.cdf(R_i - 1, num_ranks, z_i)
+
+    # likelihood of obtaining the most extreme point on the empirical CDF
+    # if the rank distribution was indeed uniform
+    return float(2 * np.min(np.minimum(bin_1, bin_2)))
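
A minimal usage sketch of the new metric, assuming plain NumPy arrays with the shapes documented above. The synthetic, well-calibrated draws and the seed are purely illustrative; only the log_gamma call itself comes from this commit:

import numpy as np
import bayesflow as bf

rng = np.random.default_rng(2025)
num_datasets, num_draws, num_vars = 200, 500, 3

# targets and "posterior draws" come from the same distribution,
# so the rank statistics should be approximately uniform
targets = rng.normal(size=(num_datasets, num_vars))
estimates = rng.normal(size=(num_datasets, num_draws, num_vars))

out = bf.diagnostics.metrics.log_gamma(estimates, targets)

# for calibrated draws, log gamma is positive with probability ~0.95 per variable
# (gamma exceeds the 5% quantile of its null distribution)
print(out["metric_name"], out["values"])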

bayesflow/distributions/diagonal_normal.py

Lines changed: 7 additions & 6 deletions

@@ -58,7 +58,6 @@ def __init__(
         self.seed_generator = seed_generator or keras.random.SeedGenerator()

         self.dim = None
-        self.log_normalization_constant = None
         self._mean = None
         self._std = None

@@ -71,17 +70,18 @@ def build(self, input_shape: Shape) -> None:
         self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
         self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")

-        self.log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self.std))
-
         if self.trainable_parameters:
             self._mean = self.add_weight(
                 shape=ops.shape(self.mean),
-                initializer=keras.initializers.get(self.mean),
+                initializer=keras.initializers.get(keras.ops.copy(self.mean)),
                 dtype="float32",
                 trainable=True,
             )
             self._std = self.add_weight(
-                shape=ops.shape(self.std), initializer=keras.initializers.get(self.std), dtype="float32", trainable=True
+                shape=ops.shape(self.std),
+                initializer=keras.initializers.get(keras.ops.copy(self.std)),
+                dtype="float32",
+                trainable=True,
             )
         else:
             self._mean = self.mean

@@ -91,7 +91,8 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)

         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
+            result += log_normalization_constant

         return result
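
The change above recomputes the normalization constant inside log_prob from the current self._std instead of caching it in build, so trainable scales stay consistent with the returned density. Below is a small NumPy/SciPy sketch of the same formula, using NumPy stand-ins for the keras ops in the class; it is illustrative only, not the BayesFlow implementation:

import numpy as np
from scipy.stats import multivariate_normal

def diag_normal_log_prob(x, mean, std, normalize=True):
    # unnormalized log density of a Gaussian with diagonal covariance diag(std**2)
    result = -0.5 * np.sum((x - mean) ** 2 / std**2, axis=-1)
    if normalize:
        # recomputed from the current std, mirroring the use of self._std in the diff
        log_normalization_constant = -0.5 * x.shape[-1] * np.log(2.0 * np.pi) - np.sum(np.log(std))
        result = result + log_normalization_constant
    return result

x = np.random.default_rng(0).normal(size=(4, 3))
mean, std = np.zeros(3), np.array([0.5, 1.0, 2.0])
reference = multivariate_normal(mean=mean, cov=np.diag(std**2)).logpdf(x)
assert np.allclose(diag_normal_log_prob(x, mean, std), reference)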

bayesflow/distributions/diagonal_student_t.py

Lines changed: 13 additions & 12 deletions

@@ -63,7 +63,6 @@ def __init__(

         self.seed_generator = seed_generator or keras.random.SeedGenerator()

-        self.log_normalization_constant = None
         self.dim = None
         self._loc = None
         self._scale = None

@@ -78,21 +77,16 @@ def build(self, input_shape: Shape) -> None:
         self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
         self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")

-        self.log_normalization_constant = (
-            -0.5 * self.dim * math.log(self.df)
-            - 0.5 * self.dim * math.log(math.pi)
-            - math.lgamma(0.5 * self.df)
-            + math.lgamma(0.5 * (self.df + self.dim))
-            - ops.sum(keras.ops.log(self.scale))
-        )
-
         if self.trainable_parameters:
             self._loc = self.add_weight(
-                shape=ops.shape(self.loc), initializer=keras.initializers.get(self.loc), dtype="float32", trainable=True
+                shape=ops.shape(self.loc),
+                initializer=keras.initializers.get(keras.ops.copy(self.loc)),
+                dtype="float32",
+                trainable=True,
             )
             self._scale = self.add_weight(
                 shape=ops.shape(self.scale),
-                initializer=keras.initializers.get(self.scale),
+                initializer=keras.initializers.get(keras.ops.copy(self.scale)),
                 dtype="float32",
                 trainable=True,
             )

@@ -105,7 +99,14 @@ def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)

         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = (
+                -0.5 * self.dim * math.log(self.df)
+                - 0.5 * self.dim * math.log(math.pi)
+                - math.lgamma(0.5 * self.df)
+                + math.lgamma(0.5 * (self.df + self.dim))
+                - ops.sum(keras.ops.log(self._scale))
+            )
+            result += log_normalization_constant

         return result
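
Same pattern for the Student-t: the normalization constant is now evaluated from self._scale at call time. The formula can be sanity-checked against scipy.stats.multivariate_t with a diagonal shape matrix; the sketch below is an illustrative NumPy translation under that assumption, not the BayesFlow class itself:

import math
import numpy as np
from scipy.stats import multivariate_t

def diag_student_t_log_prob(x, loc, scale, df):
    # log density of a multivariate Student-t with diagonal shape matrix diag(scale**2)
    dim = x.shape[-1]
    mahalanobis_term = np.sum((x - loc) ** 2 / scale**2, axis=-1)
    result = -0.5 * (df + dim) * np.log1p(mahalanobis_term / df)
    # same normalization constant as in the diff, evaluated from the current scale
    log_normalization_constant = (
        -0.5 * dim * math.log(df)
        - 0.5 * dim * math.log(math.pi)
        - math.lgamma(0.5 * df)
        + math.lgamma(0.5 * (df + dim))
        - np.sum(np.log(scale))
    )
    return result + log_normalization_constant

x = np.random.default_rng(0).normal(size=(4, 3))
loc, scale, df = np.zeros(3), np.array([0.5, 1.0, 2.0]), 5.0
reference = multivariate_t(loc=loc, shape=np.diag(scale**2), df=df).logpdf(x)
assert np.allclose(diag_student_t_log_prob(x, loc, scale, df), reference)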

bayesflow/distributions/mixture.py

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def build(self, input_shape: Shape) -> None:

         self._mixture_logits = self.add_weight(
             shape=(len(self.distributions),),
-            initializer=keras.initializers.get(self.mixture_logits),
+            initializer=keras.initializers.get(keras.ops.copy(self.mixture_logits)),
             dtype="float32",
             trainable=self.trainable_mixture,
         )

bayesflow/scores/multivariate_normal_score.py

Lines changed: 8 additions & 6 deletions

@@ -82,13 +82,15 @@ def log_prob(self, x: Tensor, mean: Tensor, cov_chol: Tensor) -> Tensor:
         """
         diff = x - mean

-        # Calculate covariance from Cholesky factors
-        covariance = keras.ops.matmul(
-            cov_chol,
-            keras.ops.swapaxes(cov_chol, -2, -1),
+        # Calculate precision from Cholesky factors of covariance matrix
+        cov_chol_inv = keras.ops.inv(cov_chol)
+        precision = keras.ops.matmul(
+            keras.ops.swapaxes(cov_chol_inv, -2, -1),
+            cov_chol_inv,
         )
-        precision = keras.ops.inv(covariance)
-        log_det_covariance = keras.ops.slogdet(covariance)[1]  # Only take the log of the determinant part
+
+        # Compute log determinant, exploiting Cholesky factors
+        log_det_covariance = keras.ops.log(keras.ops.prod(keras.ops.diagonal(cov_chol, axis1=1, axis2=2), axis=1)) * 2

         # Compute the quadratic term in the exponential of the multivariate Gaussian
         quadratic_term = keras.ops.einsum("...i,...ij,...j->...", diff, precision, diff)
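
The rewritten log_prob relies on two identities for a covariance Sigma = L L^T with Cholesky factor L (cov_chol): the precision is L^-T L^-1, and log|Sigma| = 2 * sum(log(diag(L))), which is mathematically equivalent to the log(prod(...)) * 2 form used in the diff. A standalone NumPy check of both identities (the random matrix is arbitrary, purely illustrative):

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
covariance = A @ A.T + 3 * np.eye(3)  # a positive-definite covariance matrix
L = np.linalg.cholesky(covariance)    # cov_chol in the diff: covariance = L @ L.T

# precision via the Cholesky factor, as in the new code: (L L^T)^-1 = L^-T L^-1
L_inv = np.linalg.inv(L)
precision = L_inv.T @ L_inv
assert np.allclose(precision, np.linalg.inv(covariance))

# log-determinant from the diagonal of L: log|cov| = 2 * sum(log(diag(L)))
log_det = 2 * np.sum(np.log(np.diag(L)))
assert np.isclose(log_det, np.linalg.slogdet(covariance)[1])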

tests/test_diagnostics/test_diagnostics_metrics.py

Lines changed: 54 additions & 1 deletion

@@ -1,6 +1,7 @@
-import numpy as np
 import keras
+import numpy as np
 import pytest
+from scipy.stats import binom

 import bayesflow as bf

@@ -84,6 +85,58 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
     out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)


+def test_log_gamma(random_estimates, random_targets):
+    out = bf.diagnostics.metrics.log_gamma(random_estimates, random_targets)
+    assert list(out.keys()) == ["values", "metric_name", "variable_names"]
+    assert out["values"].shape == (num_variables(random_estimates),)
+    assert out["metric_name"] == "Log Gamma"
+    assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]
+
+
+def test_log_gamma_end_to_end():
+    # This is a functional test for simulation-based calibration.
+    # First, we sample from a known generative process and then run SBC.
+    # If the log gamma statistic is correctly implemented, a 95% interval should exclude
+    # the true value 5% of the time.
+
+    N = 30  # number of samples
+    S = 1000  # number of posterior draws
+    D = 1000  # number of datasets
+
+    def run_sbc(N=N, S=S, D=D, bias=0):
+        rng = np.random.default_rng()
+        prior_draws = rng.beta(2, 2, size=D)
+        successes = rng.binomial(N, prior_draws)
+
+        # Analytical posterior:
+        # if theta ~ Beta(2, 2), then p(theta | successes) is Beta(2 + successes, 2 + N - successes).
+        posterior_draws = rng.beta(2 + successes + bias, 2 + N - successes + bias, size=(S, D))
+
+        # these ranks are uniform if bias=0
+        ranks = np.sum(posterior_draws < prior_draws, axis=0)
+
+        # this is the distribution of gamma under uniform ranks
+        gamma_null = bf.diagnostics.metrics.sbc.gamma_null_distribution(D, S, num_null_draws=100)
+        lower, upper = np.quantile(gamma_null, (0.05, 0.995))
+
+        # this is the empirical gamma
+        observed_gamma = bf.diagnostics.metrics.sbc.gamma_discrepancy(ranks, num_post_draws=S)
+
+        in_interval = lower <= observed_gamma < upper
+
+        return in_interval
+
+    sbc_calibration = [run_sbc(N=N, S=S, D=D) for _ in range(100)]
+    lower_expected, upper_expected = binom.ppf((0.0005, 0.9995), 100, 0.95)
+
+    # this test should fail with a probability of 0.1%
+    assert lower_expected <= np.sum(sbc_calibration) <= upper_expected
+
+    # sbc should almost always fail for slightly biased posterior draws
+    sbc_calibration = [run_sbc(N=N, S=S, D=D, bias=1) for _ in range(100)]
+    assert not lower_expected <= np.sum(sbc_calibration) <= upper_expected
+
+
 def test_bootstrap_comparison_shapes():
     """Test the bootstrap_comparison output shapes."""
     observed_samples = np.random.rand(10, 5)
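
For reference, the acceptance bounds in test_log_gamma_end_to_end treat the 100 replications as independent Bernoulli trials: under correct calibration each replication lands inside the null gamma interval with probability close to 0.95, so the number of in-interval replications is approximately Binomial(100, 0.95), and the test accepts any count inside a 99.9% acceptance region of that distribution. A short check of that bound, mirroring the numbers used in the test above:

from scipy.stats import binom

# two-sided 99.9% acceptance region for Binomial(n=100, p=0.95)
lower_expected, upper_expected = binom.ppf((0.0005, 0.9995), 100, 0.95)
print(lower_expected, upper_expected)  # bounds on the acceptable in-interval count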
