@@ -47,6 +47,55 @@ def test_log_post_trace():
47
47
npt .assert_allclose (logp , - 0.5 * np .log (2 * np .pi ), atol = 1e-7 )
48
48
49
49
50
+ def test_compare ():
51
+ np .random .seed (42 )
52
+ x_obs = np .random .normal (0 , 1 , size = 100 )
53
+
54
+ with pm .Model () as model0 :
55
+ mu = pm .Normal ('mu' , 0 , 1 )
56
+ x = pm .Normal ('x' , mu = mu , sd = 1 , observed = x_obs )
57
+ trace0 = pm .sample (1000 )
58
+
59
+ with pm .Model () as model1 :
60
+ mu = pm .Normal ('mu' , 0 , 1 )
61
+ x = pm .Normal ('x' , mu = mu , sd = 0.8 , observed = x_obs )
62
+ trace1 = pm .sample (1000 )
63
+
64
+ with pm .Model () as model2 :
65
+ mu = pm .Normal ('mu' , 0 , 1 )
66
+ x = pm .StudentT ('x' , nu = 1 , mu = mu , lam = 1 , observed = x_obs )
67
+ trace2 = pm .sample (1000 )
68
+
69
+ traces = [trace0 ] * 2
70
+ models = [model0 ] * 2
71
+
72
+ w_st = pm .compare (traces , models , method = 'stacking' )['weight' ]
73
+ w_bb_bma = pm .compare (traces , models , method = 'BB-pseudo-BMA' )['weight' ]
74
+ w_bma = pm .compare (traces , models , method = 'pseudo-BMA' )['weight' ]
75
+
76
+ assert_almost_equal (w_st [0 ], w_st [1 ])
77
+ assert_almost_equal (w_bb_bma [0 ], w_bb_bma [1 ])
78
+ assert_almost_equal (w_bma [0 ], w_bma [1 ])
79
+
80
+ assert_almost_equal (np .sum (w_st ), 1. )
81
+ assert_almost_equal (np .sum (w_bb_bma ), 1. )
82
+ assert_almost_equal (np .sum (w_bma ), 1. )
83
+
84
+ traces = [trace0 , trace1 , trace2 ]
85
+ models = [model0 , model1 , model2 ]
86
+ w_st = pm .compare (traces , models , method = 'stacking' )['weight' ]
87
+ w_bb_bma = pm .compare (traces , models , method = 'BB-pseudo-BMA' )['weight' ]
88
+ w_bma = pm .compare (traces , models , method = 'pseudo-BMA' )['weight' ]
89
+
90
+ assert (w_st [0 ] > w_st [1 ] > w_st [2 ])
91
+ assert (w_bb_bma [0 ] > w_bb_bma [1 ] > w_bb_bma [2 ])
92
+ assert (w_bma [0 ] > w_bma [1 ] > w_bma [2 ])
93
+
94
+ assert_almost_equal (np .sum (w_st ), 1. )
95
+ assert_almost_equal (np .sum (w_st ), 1. )
96
+ assert_almost_equal (np .sum (w_st ), 1. )
97
+
98
+
50
99
class TestStats (SeededTest ):
51
100
@classmethod
52
101
def setup_class (cls ):
0 commit comments