Skip to content

Commit ca655bc

Browse files
committed
Split dims and shape test
1 parent 13a54e6 commit ca655bc

File tree

1 file changed

+34
-9
lines changed

1 file changed

+34
-9
lines changed

pymc/tests/distributions/test_multivariate.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,19 +1399,44 @@ def test_issue_3706(self):
13991399

14001400
class TestZeroSumNormal:
14011401
@pytest.mark.parametrize(
1402-
"dims, zerosum_axes, shape",
1402+
"dims, zerosum_axes",
14031403
[
1404-
(("regions", "answers"), None, None),
1405-
(("regions", "answers"), 1, None),
1406-
(("regions", "answers"), 2, None),
1407-
(None, None, (len(COORDS["regions"]), len(COORDS["answers"]))),
1408-
(None, 1, (len(COORDS["regions"]), len(COORDS["answers"]))),
1409-
(None, 2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1404+
(("regions", "answers"), None),
1405+
(("regions", "answers"), 1),
1406+
(("regions", "answers"), 2),
14101407
],
14111408
)
1412-
def test_zsn_dims_shape(self, dims, zerosum_axes, shape):
1409+
def test_zsn_dims(self, dims, zerosum_axes):
14131410
with pm.Model(coords=COORDS) as m:
1414-
v = pm.ZeroSumNormal("v", dims=dims, shape=shape, zerosum_axes=zerosum_axes)
1411+
v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes)
1412+
s = pm.sample(10, chains=1, tune=100)
1413+
1414+
# to test forward graph
1415+
random_samples = pm.draw(v, draws=10)
1416+
1417+
assert s.posterior.v.shape == (1, 10, len(COORDS["regions"]), len(COORDS["answers"]))
1418+
1419+
ndim_supp = v.owner.op.ndim_supp
1420+
zerosum_axes = np.arange(-ndim_supp, 0)
1421+
nonzero_axes = np.arange(v.ndim - ndim_supp)
1422+
for samples in [
1423+
s.posterior.v,
1424+
random_samples,
1425+
]:
1426+
self.assert_zerosum_axes(samples, zerosum_axes)
1427+
self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False)
1428+
1429+
@pytest.mark.parametrize(
1430+
"zerosum_axes, shape",
1431+
[
1432+
(None, (len(COORDS["regions"]), len(COORDS["answers"]))),
1433+
(1, (len(COORDS["regions"]), len(COORDS["answers"]))),
1434+
(2, (len(COORDS["regions"]), len(COORDS["answers"]))),
1435+
],
1436+
)
1437+
def test_zsn_shape(self, shape, zerosum_axes):
1438+
with pm.Model(coords=COORDS) as m:
1439+
v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes)
14151440
s = pm.sample(10, chains=1, tune=100)
14161441

14171442
# to test forward graph

0 commit comments

Comments
 (0)