@@ -44,8 +44,8 @@ def reference_idata():
44
44
with model :
45
45
idata = pmx .fit (
46
46
method = "pathfinder" ,
47
- num_paths = 50 ,
48
- jitter = 10 .0 ,
47
+ num_paths = 10 ,
48
+ jitter = 12 .0 ,
49
49
random_seed = 41 ,
50
50
inference_backend = "pymc" ,
51
51
)
@@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata):
62
62
with model :
63
63
idata = pmx .fit (
64
64
method = "pathfinder" ,
65
- num_paths = 50 ,
66
- jitter = 10 .0 ,
65
+ num_paths = 10 ,
66
+ jitter = 12 .0 ,
67
67
random_seed = 41 ,
68
68
inference_backend = inference_backend ,
69
69
)
70
70
else :
71
71
idata = reference_idata
72
- np .testing .assert_allclose (idata .posterior ["mu" ].mean (), 5.0 , atol = 1.6 )
73
- np .testing .assert_allclose (idata .posterior ["tau" ].mean (), 4.15 , atol = 1.5 )
72
+ np .testing .assert_allclose (idata .posterior ["mu" ].mean (), 5.0 , atol = 0.95 )
73
+ np .testing .assert_allclose (idata .posterior ["tau" ].mean (), 4.15 , atol = 1.35 )
74
74
75
75
assert idata .posterior ["mu" ].shape == (1 , 1000 )
76
76
assert idata .posterior ["tau" ].shape == (1 , 1000 )
@@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent):
83
83
with model :
84
84
idata_conc = pmx .fit (
85
85
method = "pathfinder" ,
86
- num_paths = 50 ,
87
- jitter = 10 .0 ,
86
+ num_paths = 10 ,
87
+ jitter = 12 .0 ,
88
88
random_seed = 41 ,
89
89
inference_backend = "pymc" ,
90
90
concurrent = concurrent ,
@@ -108,15 +108,15 @@ def test_seed(reference_idata):
108
108
with model :
109
109
idata_41 = pmx .fit (
110
110
method = "pathfinder" ,
111
- num_paths = 50 ,
111
+ num_paths = 4 ,
112
112
jitter = 10.0 ,
113
113
random_seed = 41 ,
114
114
inference_backend = "pymc" ,
115
115
)
116
116
117
117
idata_123 = pmx .fit (
118
118
method = "pathfinder" ,
119
- num_paths = 50 ,
119
+ num_paths = 4 ,
120
120
jitter = 10.0 ,
121
121
random_seed = 123 ,
122
122
inference_backend = "pymc" ,
@@ -171,3 +171,33 @@ def test_bfgs_sample():
171
171
assert gamma .eval ().shape == (L , 2 * J , 2 * J )
172
172
assert phi .eval ().shape == (L , num_samples , N )
173
173
assert logq .eval ().shape == (L , num_samples )
174
+
175
+
176
+ @pytest .mark .parametrize ("importance_sampling" , ["psis" , "psir" , "identity" , None ])
177
+ def test_pathfinder_importance_sampling (importance_sampling ):
178
+ model = eight_schools_model ()
179
+
180
+ num_paths = 4
181
+ num_draws_per_path = 300
182
+ num_draws = 750
183
+
184
+ with model :
185
+ idata = pmx .fit (
186
+ method = "pathfinder" ,
187
+ num_paths = num_paths ,
188
+ num_draws_per_path = num_draws_per_path ,
189
+ num_draws = num_draws ,
190
+ maxiter = 5 ,
191
+ random_seed = 41 ,
192
+ inference_backend = "pymc" ,
193
+ importance_sampling = importance_sampling ,
194
+ )
195
+
196
+ if importance_sampling is None :
197
+ assert idata .posterior ["mu" ].shape == (num_paths , num_draws_per_path )
198
+ assert idata .posterior ["tau" ].shape == (num_paths , num_draws_per_path )
199
+ assert idata .posterior ["theta" ].shape == (num_paths , num_draws_per_path , 8 )
200
+ else :
201
+ assert idata .posterior ["mu" ].shape == (1 , num_draws )
202
+ assert idata .posterior ["tau" ].shape == (1 , num_draws )
203
+ assert idata .posterior ["theta" ].shape == (1 , num_draws , 8 )
0 commit comments