Skip to content

Commit 09f0d91

Browse files
committed
Simplify test_zsn_dims_shape
1 parent e3dc1d4 commit 09f0d91

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,27 +1418,15 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14181418

14191419
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
14201420

1421-
zerosum_axes = np.arange(-v.owner.op.ndim_supp, 0)
1422-
nonzero_axes = np.arange(v.ndim - v.owner.op.ndim_supp)
1423-
1424-
for ax in zerosum_axes:
1425-
for samples in [
1426-
s.posterior.v.mean(axis=ax),
1427-
random_samples.mean(axis=ax),
1428-
]:
1429-
assert np.isclose(
1430-
samples, 0
1431-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1432-
1433-
if nonzero_axes:
1434-
for ax in nonzero_axes:
1435-
for samples in [
1436-
s.posterior.v.mean(axis=ax),
1437-
random_samples.mean(axis=ax),
1438-
]:
1439-
assert not np.isclose(
1440-
samples, 0
1441-
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1421+
ndim_supp = v.owner.op.ndim_supp
1422+
zerosum_axes = np.arange(-ndim_supp, 0)
1423+
nonzero_axes = np.arange(v.ndim - ndim_supp)
1424+
for samples in [
1425+
s.posterior.v,
1426+
random_samples,
1427+
]:
1428+
self.assert_zerosum_axes(samples, zerosum_axes)
1429+
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
14421430

14431431
@pytest.mark.parametrize(
14441432
"error, match, shape, support_shape, zerosum_axes",
@@ -1473,6 +1461,7 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14731461
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes)
14741462
random_samples = pm.draw(base_dist, draws=100)
14751463

1464+
zerosum_axes = np.arange(-zerosum_axes, 0)
14761465
self.assert_zerosum_axes(random_samples, zerosum_axes)
14771466

14781467
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
@@ -1488,12 +1477,17 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14881477
random_samples = pm.draw(new_dist, draws=100)
14891478
self.assert_zerosum_axes(random_samples, zerosum_axes)
14901479

1491-
def assert_zerosum_axes(self, random_samples, zerosum_axes):
1492-
zerosum_axes = np.arange(-zerosum_axes, 0)
1493-
for ax in zerosum_axes:
1494-
assert np.isclose(
1495-
random_samples.mean(axis=ax), 0
1496-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1480+
def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True):
1481+
if check_zerosum_axes:
1482+
for ax in axes_to_check:
1483+
assert np.isclose(
1484+
random_samples.mean(axis=ax), 0
1485+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1486+
else:
1487+
for ax in axes_to_check:
1488+
assert not np.isclose(
1489+
random_samples.mean(axis=ax), 0
1490+
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14971491

14981492

14991493
class TestMvStudentTCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)