Skip to content

Commit e3dc1d4

Browse files
committed
Refactor test_zsn_change_dist_size
1 parent 99dbb38 commit e3dc1d4

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

pymc/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2567,7 +2567,7 @@ def change_zerosum_size(op, normal_dist, new_size, expand=False):
25672567

25682568
if expand:
25692569
original_shape = tuple(normal_dist.shape)
2570-
old_size = original_shape[len(original_shape) - op.ndim_supp :]
2570+
old_size = original_shape[: len(original_shape) - op.ndim_supp]
25712571
new_size = tuple(new_size) + old_size
25722572

25732573
return ZeroSumNormal.rv_op(

pymc/tests/distributions/test_multivariate.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,18 +1467,19 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
14671467

14681468
@pytest.mark.parametrize(
14691469
"zerosum_axes",
1470-
[(-1), (-2), (1), ((0, 1)), ((-2, -1))],
1470+
[1, 2],
14711471
)
14721472
def test_zsn_change_dist_size(self, zerosum_axes):
14731473
base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes)
14741474
random_samples = pm.draw(base_dist, draws=100)
14751475

1476-
if not isinstance(zerosum_axes, (list, tuple)):
1477-
zerosum_axes = [zerosum_axes]
14781476
self.assert_zerosum_axes(random_samples, zerosum_axes)
14791477

14801478
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
1481-
assert new_dist.eval().shape == (5, 3)
1479+
if zerosum_axes == 1:
1480+
assert new_dist.eval().shape == (5, 3, 9)
1481+
elif zerosum_axes == 2:
1482+
assert new_dist.eval().shape == (5, 3, 4, 9)
14821483
random_samples = pm.draw(new_dist, draws=100)
14831484
self.assert_zerosum_axes(random_samples, zerosum_axes)
14841485

@@ -1488,16 +1489,11 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14881489
self.assert_zerosum_axes(random_samples, zerosum_axes)
14891490

14901491
def assert_zerosum_axes(self, random_samples, zerosum_axes):
1492+
zerosum_axes = np.arange(-zerosum_axes, 0)
14911493
for ax in zerosum_axes:
1492-
if ax < 0:
1493-
assert np.isclose(
1494-
random_samples.mean(axis=ax), 0
1495-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1496-
else:
1497-
ax = ax + 1
1498-
assert np.isclose(
1499-
random_samples.mean(axis=ax), 0
1500-
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
1494+
assert np.isclose(
1495+
random_samples.mean(axis=ax), 0
1496+
).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples."
15011497

15021498

15031499
class TestMvStudentTCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)