Skip to content

Commit 44b5b91

Browse files
committed
Refactor test_zsn_dims_shape
1 parent 7e4ed0a commit 44b5b91

File tree

1 file changed

+23
-37
lines changed

1 file changed

+23
-37
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,12 +1401,12 @@ class TestZeroSumNormal:
14011401
@pytest.mark.parametrize(
14021402
"dims, zerosum_axes, shape",
14031403
[
1404-
(("regions", "answers"), "answers", None),
1405-
(("regions", "answers"), ("regions", "answers"), None),
1406-
(("regions", "answers"), 0, None),
1407-
(("regions", "answers"), -1, None),
1408-
(("regions", "answers"), (0, 1), None),
1409-
(None, -2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1404+
(("regions", "answers"), None, None),
1405+
(("regions", "answers"), 1, None),
1406+
(("regions", "answers"), 2, None),
1407+
(None, None, (len(COORDS["regions"]), len(COORDS["answers"]))),
1408+
(None, 1, (len(COORDS["regions"]), len(COORDS["answers"]))),
1409+
(None, 2, (len(COORDS["regions"]), len(COORDS["answers"]))),
14101410
],
14111411
)
14121412
def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
@@ -1419,41 +1419,27 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14191419

14201420
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
14211421

1422-
if not isinstance(zerosum_axes, (list, tuple)):
1423-
zerosum_axes = [zerosum_axes]
1422+
zerosum_axes = np.arange(-v.owner.op.ndim_supp, 0)
1423+
nonzero_axes = np.arange(v.ndim - v.owner.op.ndim_supp)
1424+
1425+
for ax in zerosum_axes:
1426+
for samples in [
1427+
s.posterior.v.mean(axis=ax),
1428+
random_samples.mean(axis=ax),
1429+
]:
1430+
assert np.isclose(
1431+
samples, 0
1432+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
14241433

1425-
if isinstance(zerosum_axes[0], str):
1426-
for ax in zerosum_axes:
1434+
if nonzero_axes:
1435+
for ax in nonzero_axes:
14271436
for samples in [
1428-
s.posterior.v.mean(dim=ax),
1429-
random_samples.mean(axis=dims.index(ax) + 1),
1437+
s.posterior.v.mean(axis=ax),
1438+
random_samples.mean(axis=ax),
14301439
]:
1431-
assert np.isclose(
1440+
assert not np.isclose(
14321441
samples, 0
1433-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1434-
1435-
nonzero_axes = list(set(dims).difference(zerosum_axes))
1436-
if nonzero_axes:
1437-
for ax in nonzero_axes:
1438-
for samples in [
1439-
s.posterior.v.mean(dim=ax),
1440-
random_samples.mean(axis=dims.index(ax) + 1),
1441-
]:
1442-
assert not np.isclose(
1443-
samples, 0
1444-
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
1445-
1446-
else:
1447-
for ax in zerosum_axes:
1448-
if ax < 0:
1449-
assert np.isclose(
1450-
s.posterior.v.mean(axis=ax), 0
1451-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1452-
else:
1453-
ax = ax + 2 # because 'chain' and 'draw' are added as new axes after sampling
1454-
assert np.isclose(
1455-
s.posterior.v.mean(axis=ax), 0
1456-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1442+
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14571443

14581444
@pytest.mark.parametrize(
14591445
"dims, zerosum_axes",

0 commit comments

Comments
 (0)