78 changes: 72 additions & 6 deletions src/arviz_stats/loo/compare.py
@@ -7,7 +7,7 @@
import pandas as pd
from arviz_base import rcParams
from scipy.optimize import Bounds, LinearConstraint, minimize
from scipy.stats import dirichlet
from scipy.stats import dirichlet, norm

from arviz_stats.loo import loo
from arviz_stats.loo.helper_loo import _diff_srs_estimator
@@ -23,8 +23,17 @@ def compare(

The ELPD is estimated by Pareto smoothed importance sampling leave-one-out
cross-validation, the same method used by :func:`arviz_stats.loo`.
The method is described in [1]_ and [2]_.
By default, the weights are estimated using ``"stacking"`` as described in [3]_.
The method is described in [2]_ and [3]_.
By default, the weights are estimated using ``"stacking"`` as described in [4]_.

If more than 11 models are compared, a diagnostic check for selection-induced bias
is performed. If potential bias is detected, avoid selecting a single model with LOO
and use model averaging or
`projection predictive inference <https://kulprit.readthedocs.io/en/latest/index.html>`_ instead.

See the EABM chapters on `Model Comparison <https://arviz-devs.github.io/EABM/Chapters/Model_comparison.html>`_,
`Model Comparison (Case Study) <https://arviz-devs.github.io/EABM/Chapters/Case_study_model_comparison.html>`_,
and `Model Comparison for Large Data <https://arviz-devs.github.io/EABM/Chapters/Model_comparison_large_data.html>`_
for more details.

Parameters
----------
@@ -117,15 +126,20 @@
References
----------

.. [1] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
.. [1] McLatchie, Y., Vehtari, A. *Efficient estimation and correction of selection-induced
bias with order statistics*. Statistics and Computing, 34, 132 (2024).
https://doi.org/10.1007/s11222-024-10442-4
arXiv preprint https://arxiv.org/abs/2309.03742

.. [2] Vehtari et al. *Practical Bayesian model evaluation using leave-one-out cross-validation
and WAIC*. Statistics and Computing. 27(5) (2017) https://doi.org/10.1007/s11222-016-9696-4
arXiv preprint https://arxiv.org/abs/1507.04544.

.. [2] Vehtari et al. *Pareto Smoothed Importance Sampling*.
.. [3] Vehtari et al. *Pareto Smoothed Importance Sampling*.
Journal of Machine Learning Research, 25(72) (2024) https://jmlr.org/papers/v25/19-556.html
arXiv preprint https://arxiv.org/abs/1507.02646

.. [3] Yao et al. *Using stacking to average Bayesian predictive distributions*
.. [4] Yao et al. *Using stacking to average Bayesian predictive distributions*
Bayesian Analysis, 13, 3 (2018). https://doi.org/10.1214/17-BA1091
arXiv preprint https://arxiv.org/abs/1704.02030.
"""
@@ -270,6 +284,9 @@ def gradient(weights):

df_comp["rank"] = df_comp["rank"].astype(int)
df_comp["warning"] = df_comp["warning"].astype(bool)

model_order = list(ics.index)
_order_stat_check(ics_dict, model_order, has_subsampling)
return df_comp.sort_values(by="elpd", ascending=False)


@@ -529,3 +546,52 @@ def _calculate_ics(
f"Encountered error trying to compute ELPD from model {name}."
) from e
return new_compare_dict


def _order_stat_check(ics_dict, model_order, has_subsampling):
"""Perform order statistics-based checks on models."""
if has_subsampling or len(ics_dict) <= 11:
return

# Use the median model as the baseline model to compute ELPD differences
baseline_idx = len(model_order) // 2
baseline_model = model_order[baseline_idx]
baseline_elpd = ics_dict[baseline_model]

elpd_diffs = []
for model_name in model_order:
if model_name == baseline_model:
elpd_diffs.append(0.0)
else:
elpd_a_vals = np.asarray(baseline_elpd.elpd_i).flatten()
elpd_b_vals = np.asarray(ics_dict[model_name].elpd_i).flatten()
elpd_diffs.append(np.sum(elpd_b_vals - elpd_a_vals))

elpd_diffs = np.array(elpd_diffs)
diff_median = np.median(elpd_diffs)
elpd_diff_trunc = elpd_diffs[elpd_diffs >= diff_median]
n_models = np.sum(~np.isnan(elpd_diff_trunc))

if n_models < 1:
return

candidate_sd = np.sqrt(1 / n_models * np.sum(elpd_diff_trunc**2))

# Defensive check to avoid a runtime error when computing the order statistic
if candidate_sd == 0 or not np.isfinite(candidate_sd):
warnings.warn(
"All models have nearly identical performance.",
UserWarning,
)
return

# Estimate expected best diff under null hypothesis
k = len(ics_dict) - 1
order_stat = norm.ppf(1 - 1 / (k * 2), loc=0, scale=candidate_sd)

if np.nanmax(elpd_diffs) <= order_stat:
warnings.warn(
"Difference in performance potentially due to chance. "
"See https://doi.org/10.1007/s11222-024-10442-4 for details.",
UserWarning,
)
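
For intuition on the threshold used by `_order_stat_check`, here is a minimal standalone sketch (not part of the diff; the simulated `elpd_diffs` values are made up). It fits a half-normal scale to the upper half of the ELPD differences and compares the observed best difference against the expected maximum of `k` null differences, the 1 - 1/(2k) normal quantile from McLatchie & Vehtari (2024):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Simulated ELPD differences of k candidate models against a baseline,
# drawn from noise only, so no model is genuinely better.
k = 11
elpd_diffs = rng.normal(loc=0.0, scale=2.0, size=k)

# Estimate the scale from the upper half of the differences (half-normal fit),
# mirroring the truncation used in _order_stat_check.
trunc = elpd_diffs[elpd_diffs >= np.median(elpd_diffs)]
sd = np.sqrt(np.mean(trunc**2))

# Expected size of the largest of k null differences: the 1 - 1/(2k) quantile
# of a normal distribution with that scale.
threshold = norm.ppf(1 - 1 / (2 * k), loc=0.0, scale=sd)

if np.max(elpd_diffs) <= threshold:
    print("Best model is within the range expected by chance alone.")
else:
    print("Best model stands out beyond what chance alone would produce.")
```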
69 changes: 69 additions & 0 deletions tests/loo/test_compare.py
@@ -411,3 +411,72 @@ def test_compare_elpd_diff_relative_to_best(centered_eight, non_centered_eight):
for i in range(len(result)):
expected_diff = best_elpd - result.iloc[i]["elpd"]
assert_almost_equal(result.iloc[i]["elpd_diff"], expected_diff, decimal=10)


def test_compare_order_stat_check(centered_eight, rng):
models = {}
for i in range(12):
loo_result = loo(centered_eight, pointwise=True)
shift = rng.normal(0, 0.1, size=loo_result.elpd_i.shape)
loo_result.elpd_i = loo_result.elpd_i + shift
loo_result.elpd = np.sum(loo_result.elpd_i)
models[f"model_{i}"] = loo_result

with pytest.warns(
UserWarning,
match="Difference in performance potentially due to chance.*10.1007/s11222-024-10442-4",
):
result = compare(models)
assert len(result) == 12
assert_allclose(result["weight"].sum(), 1.0)


def test_compare_order_stat_check_identical_models(centered_eight):
models = {f"model_{i}": centered_eight for i in range(12)}
with pytest.warns(UserWarning, match="All models have nearly identical performance"):
result = compare(models)
assert len(result) == 12
assert_allclose(result["elpd"].values, result["elpd"].values[0])
assert_allclose(result["weight"].sum(), 1.0)


def test_compare_order_stat_check_few_models(centered_eight):
models = {f"model_{i}": centered_eight for i in range(11)}
result = compare(models)
assert len(result) == 11


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_compare_order_stat_check_subsampling(centered_eight_with_sigma):
models = {}
for i in range(12):
loo_sub = loo_subsample(
centered_eight_with_sigma,
observations=np.array([0, 1, 2, 3]),
var_name="obs",
method="plpd",
log_lik_fn=log_lik_fn_subsample,
param_names=["theta"],
pointwise=True,
)
models[f"model_{i}"] = loo_sub

result = compare(models)
assert len(result) == 12
assert "subsampling_dse" in result.columns


def test_compare_order_stat_check_different_models(centered_eight):
models = {}
for i in range(12):
loo_result = loo(centered_eight, pointwise=True)
loo_result = copy.deepcopy(loo_result)
shift = 5 - (i * 5)
loo_result.elpd_i = loo_result.elpd_i + shift
loo_result.elpd = np.sum(loo_result.elpd_i)
models[f"model_{i}"] = loo_result

result = compare(models)
assert len(result) == 12
assert result.iloc[0]["elpd"] > result.iloc[-1]["elpd"]
assert result.iloc[-1]["elpd_diff"] > 10
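
Outside the diff, a hedged user-level sketch of how the new warning could surface when comparing more than 11 models. The imports (`arviz_base.load_arviz_data`, top-level `arviz_stats.compare` and `arviz_stats.loo`) are assumptions based on the test fixtures, and reusing one example posterior twelve times is purely to exercise the code path; real comparisons would use distinct models:

```python
import warnings

from arviz_base import load_arviz_data  # assumed available, as in the test fixtures
from arviz_stats import compare, loo     # assumed top-level exports

# Hypothetical setup: twelve copies of one model just to trigger the >11-model check.
idata = load_arviz_data("centered_eight")
elpds = {f"model_{i}": loo(idata, pointwise=True) for i in range(12)}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cmp_df = compare(elpds)

# With identical models the check should warn that all models perform nearly
# identically; with distinct models it warns only when the best ELPD difference
# is within the range expected by chance.
for w in caught:
    print(w.category.__name__, str(w.message))
print(cmp_df)
```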