Skip to content

Commit ce68f02

Browse files
committed
Test support_shape handling
1 parent 3e86a3e commit ce68f02

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,14 +1432,14 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14321432
"error, match, shape, support_shape, zerosum_axes",
14331433
[
14341434
(IndexError, "index out of range", (3, 4, 5), None, 4),
1435-
(AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4
1435+
(AssertionError, "does not match", (3, 4), (3,), None), # support_shape should be 4
14361436
(
14371437
AssertionError,
14381438
"does not match",
14391439
(3, 4),
14401440
(3, 4),
14411441
None,
1442-
), # doesn't work because zerosum_axes = 1
1442+
), # doesn't work because zerosum_axes = 1 by default
14431443
],
14441444
)
14451445
def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
@@ -1449,9 +1449,20 @@ def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
14491449
"v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes
14501450
)
14511451

1452-
# v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work
1452+
@pytest.mark.parametrize(
1453+
"shape, support_shape",
1454+
[
1455+
(None, (3, 4)),
1456+
((3, 4), (3, 4)),
1457+
],
1458+
)
1459+
def test_zsn_support_shape(self, shape, support_shape):
1460+
with pm.Model() as m:
1461+
v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2)
14531462

1454-
# v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't
1463+
random_samples = pm.draw(v, draws=10)
1464+
zerosum_axes = np.arange(-2, 0)
1465+
self.assert_zerosum_axes(random_samples, zerosum_axes)
14551466

14561467
@pytest.mark.parametrize(
14571468
"zerosum_axes",
@@ -1465,9 +1476,9 @@ def test_zsn_change_dist_size(self, zerosum_axes):
14651476
self.assert_zerosum_axes(random_samples, zerosum_axes)
14661477

14671478
new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False)
1468-
if zerosum_axes == 1:
1479+
try:
14691480
assert new_dist.eval().shape == (5, 3, 9)
1470-
elif zerosum_axes == 2:
1481+
except AssertionError:
14711482
assert new_dist.eval().shape == (5, 3, 4, 9)
14721483
random_samples = pm.draw(new_dist, draws=100)
14731484
self.assert_zerosum_axes(random_samples, zerosum_axes)

0 commit comments

Comments
 (0)