diff --git a/bayesflow/diagnostics/metrics/__init__.py b/bayesflow/diagnostics/metrics/__init__.py index 2acd6b5b1..10d499e11 100644 --- a/bayesflow/diagnostics/metrics/__init__.py +++ b/bayesflow/diagnostics/metrics/__init__.py @@ -4,4 +4,4 @@ from .expected_calibration_error import expected_calibration_error from .classifier_two_sample_test import classifier_two_sample_test from .model_misspecification import bootstrap_comparison, summary_space_comparison -from .sbc import log_gamma +from .calibration_log_gamma import calibration_log_gamma, gamma_null_distribution, gamma_discrepancy diff --git a/bayesflow/diagnostics/metrics/sbc.py b/bayesflow/diagnostics/metrics/calibration_log_gamma.py similarity index 97% rename from bayesflow/diagnostics/metrics/sbc.py rename to bayesflow/diagnostics/metrics/calibration_log_gamma.py index 57b364a63..54551c857 100644 --- a/bayesflow/diagnostics/metrics/sbc.py +++ b/bayesflow/diagnostics/metrics/calibration_log_gamma.py @@ -6,7 +6,7 @@ from ...utils.dict_utils import dicts_to_arrays -def log_gamma( +def calibration_log_gamma( estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, @@ -15,7 +15,8 @@ def log_gamma( quantile: float = 0.05, ): """ - Compute the log gamma discrepancy statistic, see [1] for additional information. + Compute the log gamma discrepancy statistic to test posterior calibration, + see [1] for additional information. Log gamma is log(gamma/gamma_null), where gamma_null is the 5th percentile of the null distribution under uniformity of ranks. 
That is, if adopting a hypothesis testing framework, then log_gamma < 0 implies diff --git a/tests/test_diagnostics/test_diagnostics_metrics.py b/tests/test_diagnostics/test_diagnostics_metrics.py index 789b6ea0c..f2c4c73c4 100644 --- a/tests/test_diagnostics/test_diagnostics_metrics.py +++ b/tests/test_diagnostics/test_diagnostics_metrics.py @@ -85,15 +85,15 @@ def test_expected_calibration_error(pred_models, true_models, model_names): out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose) -def test_log_gamma(random_estimates, random_targets): - out = bf.diagnostics.metrics.log_gamma(random_estimates, random_targets) +def test_calibration_log_gamma(random_estimates, random_targets): + out = bf.diagnostics.metrics.calibration_log_gamma(random_estimates, random_targets) assert list(out.keys()) == ["values", "metric_name", "variable_names"] assert out["values"].shape == (num_variables(random_estimates),) assert out["metric_name"] == "Log Gamma" assert out["variable_names"] == ["beta_0", "beta_1", "sigma"] -def test_log_gamma_end_to_end(): +def test_calibration_log_gamma_end_to_end(): # This is a function test for simulation-based calibration. # First, we sample from a known generative process and then run SBC. # If the log gamma statistic is correctly implemented, a 95% interval should exclude @@ -116,11 +116,11 @@ def run_sbc(N=N, S=S, D=D, bias=0): ranks = np.sum(posterior_draws < prior_draws, axis=0) # this is the distribution of gamma under uniform ranks - gamma_null = bf.diagnostics.metrics.sbc.gamma_null_distribution(D, S, num_null_draws=100) + gamma_null = bf.diagnostics.metrics.gamma_null_distribution(D, S, num_null_draws=100) lower, upper = np.quantile(gamma_null, (0.05, 0.995)) # this is the empirical gamma - observed_gamma = bf.diagnostics.metrics.sbc.gamma_discrepancy(ranks, num_post_draws=S) + observed_gamma = bf.diagnostics.metrics.gamma_discrepancy(ranks, num_post_draws=S) in_interval = lower <= observed_gamma < upper