bayesflow/diagnostics/metrics/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -4,4 +4,4 @@
 from .expected_calibration_error import expected_calibration_error
 from .classifier_two_sample_test import classifier_two_sample_test
 from .model_misspecification import bootstrap_comparison, summary_space_comparison
-from .sbc import log_gamma
+from .calibration_log_gamma import calibration_log_gamma, gamma_null_distribution, gamma_discrepancy
bayesflow/diagnostics/metrics/calibration_log_gamma.py

@@ -6,7 +6,7 @@
 from ...utils.dict_utils import dicts_to_arrays


-def log_gamma(
+def calibration_log_gamma(
     estimates: Mapping[str, np.ndarray] | np.ndarray,
     targets: Mapping[str, np.ndarray] | np.ndarray,
     variable_keys: Sequence[str] = None,
@@ -15,7 +15,8 @@ def log_gamma(
     quantile: float = 0.05,
 ):
     """
-    Compute the log gamma discrepancy statistic, see [1] for additional information.
+    Compute the log gamma discrepancy statistic to test posterior calibration,
+    see [1] for additional information.
     Log gamma is log(gamma/gamma_null), where gamma_null is the 5th percentile of the
     null distribution under uniformity of ranks.
     That is, if adopting a hypothesis testing framework, then log_gamma < 0 implies
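As a quick orientation for reviewers, here is a minimal usage sketch of the renamed metric. The array shapes and variable names below are illustrative assumptions, not taken from this PR; only the function name, the dict-style inputs, and the returned keys ("values", "metric_name", "variable_names") are confirmed by the updated tests further down.

    import numpy as np
    import bayesflow as bf

    rng = np.random.default_rng(0)
    # hypothetical shapes: 200 datasets, 100 posterior draws each
    estimates = {"beta_0": rng.normal(size=(200, 100)), "sigma": rng.normal(size=(200, 100))}
    targets = {"beta_0": rng.normal(size=(200,)), "sigma": rng.normal(size=(200,))}

    out = bf.diagnostics.metrics.calibration_log_gamma(estimates, targets)
    out["metric_name"]   # "Log Gamma", per the updated tests
    out["values"].shape  # one statistic per variable; values < 0 suggest miscalibration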
tests/test_diagnostics/test_diagnostics_metrics.py (10 changes: 5 additions & 5 deletions)

@@ -85,15 +85,15 @@ def test_expected_calibration_error(pred_models, true_models, model_names):
     out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose)


-def test_log_gamma(random_estimates, random_targets):
-    out = bf.diagnostics.metrics.log_gamma(random_estimates, random_targets)
+def test_calibration_log_gamma(random_estimates, random_targets):
+    out = bf.diagnostics.metrics.calibration_log_gamma(random_estimates, random_targets)
     assert list(out.keys()) == ["values", "metric_name", "variable_names"]
     assert out["values"].shape == (num_variables(random_estimates),)
     assert out["metric_name"] == "Log Gamma"
     assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]


-def test_log_gamma_end_to_end():
+def test_calibration_log_gamma_end_to_end():
     # This is a function test for simulation-based calibration.
     # First, we sample from a known generative process and then run SBC.
     # If the log gamma statistic is correctly implemented, a 95% interval should exclude
@@ -116,11 +116,11 @@ def run_sbc(N=N, S=S, D=D, bias=0):
     ranks = np.sum(posterior_draws < prior_draws, axis=0)

     # this is the distribution of gamma under uniform ranks
-    gamma_null = bf.diagnostics.metrics.sbc.gamma_null_distribution(D, S, num_null_draws=100)
+    gamma_null = bf.diagnostics.metrics.gamma_null_distribution(D, S, num_null_draws=100)
     lower, upper = np.quantile(gamma_null, (0.05, 0.995))

     # this is the empirical gamma
-    observed_gamma = bf.diagnostics.metrics.sbc.gamma_discrepancy(ranks, num_post_draws=S)
+    observed_gamma = bf.diagnostics.metrics.gamma_discrepancy(ranks, num_post_draws=S)

     in_interval = lower <= observed_gamma < upper
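For reviewers less familiar with the statistic, here is a minimal sketch (not part of the PR) of how the two newly re-exported helpers compose into the log gamma value described in the docstring: log(gamma / gamma_null), with gamma_null taken as the 5th percentile of the null distribution under uniform ranks. The helper signatures match the test above; the rank shape and direct construction of uniform ranks are simplifying assumptions.

    import numpy as np
    import bayesflow as bf

    D, S, N = 3, 100, 500  # variables, posterior draws per dataset, datasets

    # uniform ranks in {0, ..., S}, i.e., a well-calibrated posterior
    rng = np.random.default_rng(1)
    ranks = rng.integers(0, S + 1, size=(N, D))

    gamma_null = bf.diagnostics.metrics.gamma_null_distribution(D, S, num_null_draws=100)
    gamma_5th = np.quantile(gamma_null, 0.05)  # 5th percentile, per the docstring

    observed_gamma = bf.diagnostics.metrics.gamma_discrepancy(ranks, num_post_draws=S)

    log_gamma = np.log(observed_gamma / gamma_5th)  # < 0 flags likely miscalibration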