Skip to content

Commit dec4a9f

Browse files
committed
Improve test_zsn_change_dist_size
1 parent e94e4f1 commit dec4a9f

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,10 +1415,7 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14151415
s = pm.sample(10, chains=1, tune=100)
14161416

14171417
# to test forward graph
1418-
random_samples = pm.draw(
1419-
v,
1420-
draws=10,
1421-
)
1418+
random_samples = pm.draw(v, draws=10)
14221419

14231420
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
14241421

@@ -1475,14 +1472,39 @@ def test_zsn_fail_axis(self, dims, zerosum_axes):
14751472
with pm.Model(coords=COORDS) as m:
14761473
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
14771474

1478-
def test_zsn_change_dist_size(self):
1479-
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9))
1475+
@pytest.mark.parametrize(
1476+
"zerosum_axes",
1477+
[(-1), (-2), (1), ((0, 1)), ((-2, -1))],
1478+
)
1479+
def test_zsn_change_dist_size(self, zerosum_axes):
1480+
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes)
1481+
random_samples = pm.draw(base_dist, draws=100)
1482+
1483+
if not isinstance(zerosum_axes, (list, tuple)):
1484+
zerosum_axes = [zerosum_axes]
1485+
self.assert_zerosum_axes(random_samples, zerosum_axes)
14801486

14811487
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
14821488
assert new_dist.eval().shape == (5, 3)
1489+
random_samples = pm.draw(new_dist, draws=100)
1490+
self.assert_zerosum_axes(random_samples, zerosum_axes)
14831491

14841492
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True)
14851493
assert new_dist.eval().shape == (5, 3, 4, 9)
1494+
random_samples = pm.draw(new_dist, draws=100)
1495+
self.assert_zerosum_axes(random_samples, zerosum_axes)
1496+
1497+
def assert_zerosum_axes(self, random_samples, zerosum_axes):
1498+
for ax in zerosum_axes:
1499+
if ax < 0:
1500+
assert np.isclose(
1501+
random_samples.mean(axis=ax), 0
1502+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1503+
else:
1504+
ax = ax + 1
1505+
assert np.isclose(
1506+
random_samples.mean(axis=ax), 0
1507+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
14861508

14871509

14881510
class TestMvStudentTCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)