Skip to content

Commit 3823b81

Browse files
committed
test: Add pathfinder importance sampling test cases
- Test different importance sampling approaches (psis, psir, identity, None)
1 parent f63d5ce commit 3823b81

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

tests/test_pathfinder.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def reference_idata():
4444
with model:
4545
idata = pmx.fit(
4646
method="pathfinder",
47-
num_paths=50,
48-
jitter=10.0,
47+
num_paths=10,
48+
jitter=12.0,
4949
random_seed=41,
5050
inference_backend="pymc",
5151
)
@@ -62,15 +62,15 @@ def test_pathfinder(inference_backend, reference_idata):
6262
with model:
6363
idata = pmx.fit(
6464
method="pathfinder",
65-
num_paths=50,
66-
jitter=10.0,
65+
num_paths=10,
66+
jitter=12.0,
6767
random_seed=41,
6868
inference_backend=inference_backend,
6969
)
7070
else:
7171
idata = reference_idata
72-
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6)
73-
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5)
72+
np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=0.95)
73+
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.35)
7474

7575
assert idata.posterior["mu"].shape == (1, 1000)
7676
assert idata.posterior["tau"].shape == (1, 1000)
@@ -83,8 +83,8 @@ def test_concurrent_results(reference_idata, concurrent):
8383
with model:
8484
idata_conc = pmx.fit(
8585
method="pathfinder",
86-
num_paths=50,
87-
jitter=10.0,
86+
num_paths=10,
87+
jitter=12.0,
8888
random_seed=41,
8989
inference_backend="pymc",
9090
concurrent=concurrent,
@@ -108,15 +108,15 @@ def test_seed(reference_idata):
108108
with model:
109109
idata_41 = pmx.fit(
110110
method="pathfinder",
111-
num_paths=50,
111+
num_paths=4,
112112
jitter=10.0,
113113
random_seed=41,
114114
inference_backend="pymc",
115115
)
116116

117117
idata_123 = pmx.fit(
118118
method="pathfinder",
119-
num_paths=50,
119+
num_paths=4,
120120
jitter=10.0,
121121
random_seed=123,
122122
inference_backend="pymc",
@@ -171,3 +171,33 @@ def test_bfgs_sample():
171171
assert gamma.eval().shape == (L, 2 * J, 2 * J)
172172
assert phi.eval().shape == (L, num_samples, N)
173173
assert logq.eval().shape == (L, num_samples)
174+
175+
176+
@pytest.mark.parametrize("importance_sampling", ["psis", "psir", "identity", None])
177+
def test_pathfinder_importance_sampling(importance_sampling):
178+
model = eight_schools_model()
179+
180+
num_paths = 4
181+
num_draws_per_path = 300
182+
num_draws = 750
183+
184+
with model:
185+
idata = pmx.fit(
186+
method="pathfinder",
187+
num_paths=num_paths,
188+
num_draws_per_path=num_draws_per_path,
189+
num_draws=num_draws,
190+
maxiter=5,
191+
random_seed=41,
192+
inference_backend="pymc",
193+
importance_sampling=importance_sampling,
194+
)
195+
196+
if importance_sampling is None:
197+
assert idata.posterior["mu"].shape == (num_paths, num_draws_per_path)
198+
assert idata.posterior["tau"].shape == (num_paths, num_draws_per_path)
199+
assert idata.posterior["theta"].shape == (num_paths, num_draws_per_path, 8)
200+
else:
201+
assert idata.posterior["mu"].shape == (1, num_draws)
202+
assert idata.posterior["tau"].shape == (1, num_draws)
203+
assert idata.posterior["theta"].shape == (1, num_draws, 8)

0 commit comments

Comments
 (0)