|
28 | 28 | from aeppl.logprob import ParameterValueError
|
29 | 29 | from aesara.tensor import TensorVariable
|
30 | 30 | from aesara.tensor.random.utils import broadcast_params
|
31 |
| -from numpy import AxisError |
32 | 31 |
|
33 | 32 | import pymc as pm
|
34 | 33 |
|
@@ -1442,21 +1441,29 @@ def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
|
1442 | 1441 | ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples."
|
1443 | 1442 |
|
1444 | 1443 | @pytest.mark.parametrize(
|
1445 |
| - "dims, zerosum_axes", |
| 1444 | + "error, match, shape, support_shape, zerosum_axes", |
1446 | 1445 | [
|
1447 |
| - (("regions", "answers"), 2), |
1448 |
| - (("regions", "answers"), (0, -2)), |
| 1446 | + (IndexError, "index out of range", (3, 4, 5), None, 4), |
| 1447 | + (AssertionError, "does not match", (3, 4), 3, None), # support_shape should be 4 |
| 1448 | + ( |
| 1449 | + AssertionError, |
| 1450 | + "does not match", |
| 1451 | + (3, 4), |
| 1452 | + (3, 4), |
| 1453 | + None, |
| 1454 | + ), # doesn't work because zerosum_axes = 1 |
1449 | 1455 | ],
|
1450 | 1456 | )
|
1451 |
| - def test_zsn_fail_axis(self, dims, zerosum_axes): |
1452 |
| - if isinstance(zerosum_axes, (list, tuple)): |
1453 |
| - with pytest.raises(ValueError, match="repeated axis"): |
1454 |
| - with pm.Model(coords=COORDS) as m: |
1455 |
| - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) |
1456 |
| - else: |
1457 |
| - with pytest.raises(AxisError, match="out of bounds"): |
1458 |
| - with pm.Model(coords=COORDS) as m: |
1459 |
| - _ = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) |
| 1457 | + def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): |
| 1458 | + with pytest.raises(error, match=match): |
| 1459 | + with pm.Model() as m: |
| 1460 | + _ = pm.ZeroSumNormal( |
| 1461 | + "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes |
| 1462 | + ) |
| 1463 | + |
| 1464 | + # v = pm.ZeroSumNormal("v", support_shape=(3, 4), zerosum_axes=2) # should work |
| 1465 | + |
| 1466 | + # v = pm.ZeroSumNormal("v", shape=(3, 4), support_shape=(3, 4), zerosum_axes=2) this should work but doesn't |
1460 | 1467 |
|
1461 | 1468 | @pytest.mark.parametrize(
|
1462 | 1469 | "zerosum_axes",
|
|
0 commit comments