Skip to content

Commit 22812fd

Browse files
aloctavodiatwiecki
authored andcommitted
test compare
1 parent b29b2f6 commit 22812fd

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

pymc3/tests/test_stats.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,55 @@ def test_log_post_trace():
4747
npt.assert_allclose(logp, -0.5 * np.log(2 * np.pi), atol=1e-7)
4848

4949

50+
def test_compare():
51+
np.random.seed(42)
52+
x_obs = np.random.normal(0, 1, size=100)
53+
54+
with pm.Model() as model0:
55+
mu = pm.Normal('mu', 0, 1)
56+
x = pm.Normal('x', mu=mu, sd=1, observed=x_obs)
57+
trace0 = pm.sample(1000)
58+
59+
with pm.Model() as model1:
60+
mu = pm.Normal('mu', 0, 1)
61+
x = pm.Normal('x', mu=mu, sd=0.8, observed=x_obs)
62+
trace1 = pm.sample(1000)
63+
64+
with pm.Model() as model2:
65+
mu = pm.Normal('mu', 0, 1)
66+
x = pm.StudentT('x', nu=1, mu=mu, lam=1, observed=x_obs)
67+
trace2 = pm.sample(1000)
68+
69+
traces = [trace0] * 2
70+
models = [model0] * 2
71+
72+
w_st = pm.compare(traces, models, method='stacking')['weight']
73+
w_bb_bma = pm.compare(traces, models, method='BB-pseudo-BMA')['weight']
74+
w_bma = pm.compare(traces, models, method='pseudo-BMA')['weight']
75+
76+
assert_almost_equal(w_st[0], w_st[1])
77+
assert_almost_equal(w_bb_bma[0], w_bb_bma[1])
78+
assert_almost_equal(w_bma[0], w_bma[1])
79+
80+
assert_almost_equal(np.sum(w_st), 1.)
81+
assert_almost_equal(np.sum(w_bb_bma), 1.)
82+
assert_almost_equal(np.sum(w_bma), 1.)
83+
84+
traces = [trace0, trace1, trace2]
85+
models = [model0, model1, model2]
86+
w_st = pm.compare(traces, models, method='stacking')['weight']
87+
w_bb_bma = pm.compare(traces, models, method='BB-pseudo-BMA')['weight']
88+
w_bma = pm.compare(traces, models, method='pseudo-BMA')['weight']
89+
90+
assert(w_st[0] > w_st[1] > w_st[2])
91+
assert(w_bb_bma[0] > w_bb_bma[1] > w_bb_bma[2])
92+
assert(w_bma[0] > w_bma[1] > w_bma[2])
93+
94+
assert_almost_equal(np.sum(w_st), 1.)
95+
assert_almost_equal(np.sum(w_st), 1.)
96+
assert_almost_equal(np.sum(w_st), 1.)
97+
98+
5099
class TestStats(SeededTest):
51100
@classmethod
52101
def setup_class(cls):

0 commit comments

Comments
 (0)