Skip to content

Commit 99dbb38

Browse files
committed
Refactor test_zsn_fail_axis
1 parent 44b5b91 commit 99dbb38

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from aeppl.logprob import ParameterValueError
2929
from aesara.tensor import TensorVariable
3030
from aesara.tensor.random.utils import broadcast_params
31-
from numpy import AxisError
3231

3332
import pymc as pm
3433

@@ -1442,21 +1441,29 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
14421441
).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
14431442

14441443
@pytest.mark.parametrize(
1445-
"dims, zerosum_axes",
1444+
"error, match, shape, support_shape, zerosum_axes",
14461445
[
1447-
(("regions", "answers"), 2),
1448-
(("regions", "answers"), (0, -2)),
1446+
(IndexError, "index out of range", (3, 4, 5), None, 4),
1447+
(AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4
1448+
(
1449+
AssertionError,
1450+
"does not match",
1451+
(3, 4),
1452+
(3, 4),
1453+
None,
1454+
), # doesn't work because zerosum_axes = 1
14491455
],
14501456
)
1451-
def test_zsn_fail_axis(self, dims, zerosum_axes):
1452-
if isinstance(zerosum_axes, (list, tuple)):
1453-
with pytest.raises(ValueError, match="repeated axis"):
1454-
with pm.Model(coords=COORDS) as m:
1455-
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1456-
else:
1457-
with pytest.raises(AxisError, match="out of bounds"):
1458-
with pm.Model(coords=COORDS) as m:
1459-
_ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1457+
def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes):
1458+
with pytest.raises(error, match=match):
1459+
with pm.Model() as m:
1460+
_ = pm.ZeroSumNormal(
1461+
"v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes
1462+
)
1463+
1464+
# v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work
1465+
1466+
# v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't
14601467

14611468
@pytest.mark.parametrize(
14621469
"zerosum_axes",

0 commit comments

Comments
 (0)