
Commit 1d8c010

Add test for compare method standard error sorting consistency (#2407)
* Add test for compare method standard error sorting consistency (#2350)
* Fix pylint issues: remove trailing whitespace and add final newline
* Fix Black formatting: add blank line before function definition
* add to changelog

---------

Co-authored-by: Oriol (ProDesk) <[email protected]>
1 parent 2e4cfcf commit 1d8c010

2 files changed: +43 -1 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 - `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
 - Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
 - Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
+- Test compare dataframe stays consistent independently of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407))
 
 ### Documentation
 - Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))

arviz/tests/base_tests/test_stats.py

Lines changed: 42 additions & 1 deletion
@@ -14,7 +14,7 @@
 from xarray import DataArray, Dataset
 from xarray_einstats.stats import XrContinuousRV
 
-from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
+from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
 from ...rcparams import rcParams
 from ...stats import (
     apply_test_function,
@@ -882,3 +882,44 @@ def test_bayes_factor():
     bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
     assert bf_dict0["BF10"] > bf_dict0["BF01"]
     assert bf_dict1["BF10"] < bf_dict1["BF01"]
+
+
+def test_compare_sorting_consistency():
+    chains, draws = 4, 1000
+
+    # Model 1 - good fit
+    log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
+    posterior1 = Dataset(
+        {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    log_like1 = Dataset(
+        {"y": (("chain", "draw"), log_lik1)},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
+
+    # Model 2 - poor fit (higher variance)
+    log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
+    posterior2 = Dataset(
+        {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    log_like2 = Dataset(
+        {"y": (("chain", "draw"), log_lik2)},
+        coords={"chain": range(chains), "draw": range(draws)},
+    )
+    data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
+
+    # Compare models in different orders
+    comp_dict1 = {"M1": data1, "M2": data2}
+    comp_dict2 = {"M2": data2, "M1": data1}
+
+    comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
+    comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
+
+    assert comparison1.index.tolist() == comparison2.index.tolist()
+
+    se1 = comparison1["se"].values
+    se2 = comparison2["se"].values
+    np.testing.assert_array_almost_equal(se1, se2)
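
Note: the property this test targets, that the result does not depend on the insertion order of the input dict, can be illustrated without ArviZ. The following is a minimal standalone sketch, not part of this commit. The helper `is_order_invariant` and the stand-in functions `rank_models` and `insertion_order` are hypothetical names introduced here for illustration only; they are not ArviZ APIs.

```python
from itertools import permutations


def is_order_invariant(func, mapping):
    """Return True if func(mapping) is identical for every insertion order."""
    # Rebuild the dict once per permutation of its items and collect the results.
    results = [func(dict(items)) for items in permutations(mapping.items())]
    return all(result == results[0] for result in results[1:])


def rank_models(scores):
    """Hypothetical stand-in for a comparison routine: names sorted best first."""
    return sorted(scores, key=scores.get, reverse=True)


def insertion_order(scores):
    """Deliberately order-dependent counterexample: names as they were inserted."""
    return list(scores)


# Made-up scores, chosen without ties so the ranking is unambiguous.
scores = {"M1": -2.0, "M2": -5.0, "M3": -3.5}
print(is_order_invariant(rank_models, scores))
print(is_order_invariant(insertion_order, scores))
```

The test in this commit checks only the two orderings of a two-model dict. Enumerating every permutation, as sketched here, grows factorially with the number of models, so it is practical only for small inputs.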
