|
14 | 14 | from xarray import DataArray, Dataset |
15 | 15 | from xarray_einstats.stats import XrContinuousRV |
16 | 16 |
|
17 | | -from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data |
| 17 | +from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData |
18 | 18 | from ...rcparams import rcParams |
19 | 19 | from ...stats import ( |
20 | 20 | apply_test_function, |
@@ -882,3 +882,44 @@ def test_bayes_factor(): |
882 | 882 | bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0) |
883 | 883 | assert bf_dict0["BF10"] > bf_dict0["BF01"] |
884 | 884 | assert bf_dict1["BF10"] < bf_dict1["BF01"] |
| 885 | + |
| 886 | + |
| 887 | +def test_compare_sorting_consistency(): |
| 888 | + chains, draws = 4, 1000 |
| 889 | + |
| 890 | + # Model 1 - good fit |
| 891 | + log_lik1 = np.random.normal(-2, 1, size=(chains, draws)) |
| 892 | + posterior1 = Dataset( |
| 893 | + {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))}, |
| 894 | + coords={"chain": range(chains), "draw": range(draws)}, |
| 895 | + ) |
| 896 | + log_like1 = Dataset( |
| 897 | + {"y": (("chain", "draw"), log_lik1)}, |
| 898 | + coords={"chain": range(chains), "draw": range(draws)}, |
| 899 | + ) |
| 900 | + data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1) |
| 901 | + |
| 902 | + # Model 2 - poor fit (higher variance) |
| 903 | + log_lik2 = np.random.normal(-5, 2, size=(chains, draws)) |
| 904 | + posterior2 = Dataset( |
| 905 | + {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))}, |
| 906 | + coords={"chain": range(chains), "draw": range(draws)}, |
| 907 | + ) |
| 908 | + log_like2 = Dataset( |
| 909 | + {"y": (("chain", "draw"), log_lik2)}, |
| 910 | + coords={"chain": range(chains), "draw": range(draws)}, |
| 911 | + ) |
| 912 | + data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2) |
| 913 | + |
| 914 | + # Compare models in different orders |
| 915 | + comp_dict1 = {"M1": data1, "M2": data2} |
| 916 | + comp_dict2 = {"M2": data2, "M1": data1} |
| 917 | + |
| 918 | + comparison1 = compare(comp_dict1, method="bb-pseudo-bma") |
| 919 | + comparison2 = compare(comp_dict2, method="bb-pseudo-bma") |
| 920 | + |
| 921 | + assert comparison1.index.tolist() == comparison2.index.tolist() |
| 922 | + |
| 923 | + se1 = comparison1["se"].values |
| 924 | + se2 = comparison2["se"].values |
| 925 | + np.testing.assert_array_almost_equal(se1, se2) |
0 commit comments