
Commit a7d8868

Add order statistic diagnostic for compare() (#237)
* feat: add order-statistic check for model comparison
* docs: make diagnostic check less technical
* refactor: improve code quality and fix tests
1 parent 1a03d17 commit a7d8868

File tree: 2 files changed, +130 -6 lines changed


src/arviz_stats/loo/compare.py

Lines changed: 70 additions & 6 deletions
```diff
@@ -7,7 +7,7 @@
 import pandas as pd
 from arviz_base import rcParams
 from scipy.optimize import Bounds, LinearConstraint, minimize
-from scipy.stats import dirichlet
+from scipy.stats import dirichlet, norm
 
 from arviz_stats.loo import loo
 from arviz_stats.loo.helper_loo import _diff_srs_estimator
@@ -23,8 +23,17 @@ def compare(
 
     The ELPD is estimated by Pareto smoothed importance sampling leave-one-out
     cross-validation, the same method used by :func:`arviz_stats.loo`.
-    The method is described in [1]_ and [2]_.
-    By default, the weights are estimated using ``"stacking"`` as described in [3]_.
+    The method is described in [2]_ and [3]_.
+    By default, the weights are estimated using ``"stacking"`` as described in [4]_.
+
+    If more than 11 models are compared, a diagnostic check for selection bias
+    is performed. If detected, avoid LOO-based selection and use model averaging
+    or `projection predictive inference <https://kulprit.readthedocs.io/en/latest/index.html>`_.
+
+    See the EABM chapters on `Model Comparison <https://arviz-devs.github.io/EABM/Chapters/Model_comparison.html>`_,
+    `Model Comparison (Case Study) <https://arviz-devs.github.io/EABM/Chapters/Case_study_model_comparison.html>`_,
+    and `Model Comparison for Large Data <https://arviz-devs.github.io/EABM/Chapters/Model_comparison_large_data.html>`_
+    for more details.
 
     Parameters
     ----------
@@ -117,15 +126,20 @@ def compare(
     References
     ----------
 
-    .. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
+    .. [1] McLatchie, Y., Vehtari, A. *Efficient estimation and correction of selection-induced
+       bias with order statistics*. Statistics and Computing, 34, 132 (2024).
+       https://doi.org/10.1007/s11222-024-10442-4
+       arXiv preprint https://arxiv.org/abs/2309.03742
+
+    .. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
        and WAIC*. Statistics and Computing, 27(5) (2017). https://doi.org/10.1007/s11222-016-9696-4
        arXiv preprint https://arxiv.org/abs/1507.04544
 
-    .. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
+    .. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*.
        Journal of Machine Learning Research, 25(72) (2024). https://jmlr.org/papers/v25/19-556.html
        arXiv preprint https://arxiv.org/abs/1507.02646
 
-    .. [3] Yao et al. *Using stacking to average Bayesian predictive distributions*
+    .. [4] Yao et al. *Using stacking to average Bayesian predictive distributions*.
        Bayesian Analysis, 13(3) (2018). https://doi.org/10.1214/17-BA1091
        arXiv preprint https://arxiv.org/abs/1704.02030
@@ -270,6 +284,9 @@ def gradient(weights):
 
     df_comp["rank"] = df_comp["rank"].astype(int)
     df_comp["warning"] = df_comp["warning"].astype(bool)
+
+    model_order = list(ics.index)
+    _order_stat_check(ics_dict, model_order, has_subsampling)
     return df_comp.sort_values(by="elpd", ascending=False)
 
 
@@ -529,3 +546,50 @@ def _calculate_ics(
             f"Encountered error trying to compute ELPD from model {name}."
         ) from e
     return new_compare_dict
+
+
+def _order_stat_check(ics_dict, model_order, has_subsampling):
+    """Perform order statistics-based checks on models."""
+    if has_subsampling or len(ics_dict) <= 11:
+        return
+
+    # Use the median model as the baseline model to compute ELPD differences
+    baseline_idx = len(model_order) // 2
+    baseline_model = model_order[baseline_idx]
+    baseline_elpd = ics_dict[baseline_model]
+
+    elpd_diffs = np.zeros(len(model_order))
+    for idx, model_name in enumerate(model_order):
+        if model_name != baseline_model:
+            elpd_a_vals = np.ravel(baseline_elpd.elpd_i)
+            elpd_b_vals = np.ravel(ics_dict[model_name].elpd_i)
+            elpd_diffs[idx] = np.sum(elpd_b_vals - elpd_a_vals)
+
+    elpd_diffs = np.array(elpd_diffs)
+    diff_median = np.median(elpd_diffs)
+    elpd_diff_trunc = elpd_diffs[elpd_diffs >= diff_median]
+    n_models = np.sum(~np.isnan(elpd_diff_trunc))
+
+    if n_models < 1:
+        return
+
+    candidate_sd = np.sqrt(1 / n_models * np.sum(elpd_diff_trunc**2))
+
+    # Defensive check to avoid a runtime error when computing the order statistic
+    if candidate_sd == 0 or not np.isfinite(candidate_sd):
+        warnings.warn(
+            "All models have nearly identical performance.",
+            UserWarning,
+        )
+        return
+
+    # Estimate expected best diff under null hypothesis
+    k = len(ics_dict) - 1
+    order_stat = norm.ppf(1 - 1 / (k * 2), loc=0, scale=candidate_sd)
+
+    if np.nanmax(elpd_diffs) <= order_stat:
+        warnings.warn(
+            "Difference in performance potentially due to chance. "
+            "See https://doi.org/10.1007/s11222-024-10442-4 for details.",
+            UserWarning,
+        )
```
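For intuition, here is a minimal, self-contained sketch of the decision rule that `_order_stat_check` implements, run on toy data. The variable names are illustrative and this is not the committed implementation; it only mirrors the two key steps: estimating a null scale from the upper half of the ELPD differences, and comparing the best difference to the 1 - 1/(2k) normal quantile.

```python
# Sketch of the diagnostic's decision rule on toy data (illustrative only,
# not the committed implementation).
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
k = 11                                # candidate models besides the baseline
elpd_diffs = rng.normal(0.0, 1.0, k)  # toy summed ELPD differences: pure noise

# Estimate the null scale from the upper half of the differences,
# mirroring candidate_sd in _order_stat_check.
upper_half = elpd_diffs[elpd_diffs >= np.median(elpd_diffs)]
sd = np.sqrt(np.mean(upper_half**2))

# Expected best difference under the null: the 1 - 1/(2k) quantile of N(0, sd^2).
threshold = norm.ppf(1 - 1 / (2 * k), loc=0, scale=sd)

print(f"best diff = {elpd_diffs.max():.2f}, null threshold = {threshold:.2f}")
if elpd_diffs.max() <= threshold:
    print("difference in performance potentially due to chance")
```

When no model genuinely stands out, the best observed difference falls below this threshold more often than not, and `compare()` then warns that the ranking may reflect chance rather than real predictive gains.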

tests/loo/test_compare.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -411,3 +411,63 @@ def test_compare_elpd_diff_relative_to_best(centered_eight, non_centered_eight):
     for i in range(len(result)):
         expected_diff = best_elpd - result.iloc[i]["elpd"]
         assert_almost_equal(result.iloc[i]["elpd_diff"], expected_diff, decimal=10)
+
+
+def test_compare_order_stat_check(centered_eight, rng):
+    models = {}
+    base_loo = loo(centered_eight, pointwise=True)
+
+    for i in range(12):
+        loo_result = copy.deepcopy(base_loo)
+        shift = rng.normal(0, 0.1, size=loo_result.elpd_i.shape)
+        loo_result.elpd_i = loo_result.elpd_i + shift
+        loo_result.elpd = np.sum(loo_result.elpd_i)
+        models[f"model_{i}"] = loo_result
+
+    with pytest.warns(
+        UserWarning,
+        match="Difference in performance potentially due to chance.*10.1007/s11222-024-10442-4",
+    ):
+        result = compare(models)
+    assert len(result) == 12
+    assert_allclose(result["weight"].sum(), 1.0)
+
+
+def test_compare_order_stat_check_identical_models(centered_eight):
+    models = {f"model_{i}": centered_eight for i in range(12)}
+    with pytest.warns(UserWarning, match="All models have nearly identical performance"):
+        result = compare(models)
+    assert len(result) == 12
+    assert_allclose(result["elpd"].values, result["elpd"].values[0])
+    assert_allclose(result["weight"].sum(), 1.0)
+
+
+def test_compare_order_stat_check_few_models(centered_eight):
+    models = {f"model_{i}": centered_eight for i in range(11)}
+    result = compare(models)
+    assert len(result) == 11
+
+
+@pytest.mark.filterwarnings("ignore::UserWarning")
+def test_compare_order_stat_check_subsampling(centered_eight_with_sigma, rng):
+    base_loo_sub = loo_subsample(
+        centered_eight_with_sigma,
+        observations=np.array([0, 1, 2, 3]),
+        var_name="obs",
+        method="plpd",
+        log_lik_fn=log_lik_fn_subsample,
+        param_names=["theta"],
+        pointwise=True,
+    )
+
+    models = {}
+    for i in range(12):
+        loo_sub = copy.deepcopy(base_loo_sub)
+        shift = rng.normal(0, 0.1, size=loo_sub.elpd_i.shape)
+        loo_sub.elpd_i = loo_sub.elpd_i + shift
+        loo_sub.elpd = np.sum(loo_sub.elpd_i)
+        models[f"model_{i}"] = loo_sub
+
+    result = compare(models)
+    assert len(result) == 12
+    assert "subsampling_dse" in result.columns
```
