Skip to content

Commit 6612a24

Browse files
committed
Increase tolerance for test_zsn_variance
1 parent 48dafe9 commit 6612a24

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2438,15 +2438,15 @@ class ZeroSumNormal(Distribution):
24382438
}
24392439
with pm.Model(coords=COORDS) as m:
24402440
# the zero sum axis will be 'answers'
2441-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2441+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
24422442
24432443
with pm.Model(coords=COORDS) as m:
24442444
# the zero sum axes will be 'answers' and 'regions'
2445-
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
2445+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
24462446
24472447
with pm.Model(coords=COORDS) as m:
24482448
# the zero sum axes will be the last two
2449-
...: v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
2449+
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
24502450
"""
24512451
rv_type = ZeroSumNormalRV
24522452

pymc/tests/distributions/test_multivariate.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,13 +1381,12 @@ def test_issue_3706(self):
13811381
assert prior_pred["X"].shape == (1, N, 2)
13821382

13831383

1384-
COORDS = {
1385-
"regions": ["a", "b", "c"],
1386-
"answers": ["yes", "no", "whatever", "don't understand question"],
1387-
}
1388-
1389-
13901384
class TestZeroSumNormal:
1385+
coords = {
1386+
"regions": ["a", "b", "c"],
1387+
"answers": ["yes", "no", "whatever", "don't understand question"],
1388+
}
1389+
13911390
def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True):
13921391
if check_zerosum_axes:
13931392
for ax in axes_to_check:
@@ -1409,14 +1408,19 @@ def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=
14091408
],
14101409
)
14111410
def test_zsn_dims(self, dims, zerosum_axes):
1412-
with pm.Model(coords=COORDS) as m:
1411+
with pm.Model(coords=self.coords) as m:
14131412
v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
14141413
s = pm.sample(10, chains=1, tune=100)
14151414

14161415
# to test forward graph
14171416
random_samples = pm.draw(v, draws=10)
14181417

1419-
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
1418+
assert s.posterior.v.shape == (
1419+
1,
1420+
10,
1421+
len(self.coords["regions"]),
1422+
len(self.coords["answers"]),
1423+
)
14201424

14211425
ndim_supp = v.owner.op.ndim_supp
14221426
zerosum_axes = np.arange(-ndim_supp, 0)
@@ -1429,22 +1433,25 @@ def test_zsn_dims(self, dims, zerosum_axes):
14291433
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14301434

14311435
@pytest.mark.parametrize(
1432-
"zerosum_axes, shape",
1433-
[
1434-
(None, (len(COORDS["regions"]), len(COORDS["answers"]))),
1435-
(1, (len(COORDS["regions"]), len(COORDS["answers"]))),
1436-
(2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1437-
],
1436+
"zerosum_axes",
1437+
(None, 1, 2),
14381438
)
1439-
def test_zsn_shape(self, shape, zerosum_axes):
1440-
with pm.Model(coords=COORDS) as m:
1439+
def test_zsn_shape(self, zerosum_axes):
1440+
shape = (len(self.coords["regions"]), len(self.coords["answers"]))
1441+
1442+
with pm.Model(coords=self.coords) as m:
14411443
v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes)
14421444
s = pm.sample(10, chains=1, tune=100)
14431445

14441446
# to test forward graph
14451447
random_samples = pm.draw(v, draws=10)
14461448

1447-
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
1449+
assert s.posterior.v.shape == (
1450+
1,
1451+
10,
1452+
len(self.coords["regions"]),
1453+
len(self.coords["answers"]),
1454+
)
14481455

14491456
ndim_supp = v.owner.op.ndim_supp
14501457
zerosum_axes = np.arange(-ndim_supp, 0)
@@ -1525,13 +1532,13 @@ def test_zsn_change_dist_size(self, zerosum_axes):
15251532
)
15261533
def test_zsn_variance(self, sigma, n):
15271534

1528-
dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=n)
1529-
random_samples = pm.draw(dist, draws=100_000)
1535+
dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=(100_000, n))
1536+
random_samples = pm.draw(dist)
15301537

15311538
empirical_var = random_samples.var(axis=0)
15321539
theoretical_var = sigma**2 * (n - 1) / n
15331540

1534-
np.testing.assert_allclose(empirical_var, theoretical_var, rtol=1e-02)
1541+
np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4)
15351542

15361543
@pytest.mark.parametrize(
15371544
"sigma, shape, zerosum_axes, mvn_axes",

0 commit comments

Comments
 (0)