Skip to content

Commit 1f29284

Browse files
Test fixes
1 parent 27fd53d commit 1f29284

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

tests/distributions/test_multivariate.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,11 +1471,15 @@ def test_with_lkjcorr_matrix(
14711471
prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False)
14721472

14731473
assert prior["corr_mat"].shape == (10, 3, 3) # square
1474-
assert np.allclose(prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0) # 1.0 on diagonal
14751474
assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all() # symmetric
1476-
assert (
1477-
prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0
1478-
) # constrained between -1 and 1
1475+
1476+
np.testing.assert_allclose(
1477+
prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0
1478+
) # 1.0 on diagonal
1479+
1480+
# constrained between -1 and 1
1481+
assert prior["corr_mat"].max() <= (1.0 + 1e-12)
1482+
assert prior["corr_mat"].min() >= (-1.0 - 1e-12)
14791483

14801484
def test_issue_3758(self):
14811485
np.random.seed(42)
@@ -2172,8 +2176,6 @@ class TestLKJCorr(BaseTestDistributionRandom):
21722176
]
21732177

21742178
def check_draws_match_expected(self):
2175-
from pymc.distributions import CustomDist
2176-
21772179
def ref_rand(size, n, eta):
21782180
shape = int(n * (n - 1) // 2)
21792181
beta = eta - 1 + n / 2
@@ -2182,16 +2184,9 @@ def ref_rand(size, n, eta):
21822184

21832185
# If passed as a domain, continuous_random_tester would make `n` a shared variable
21842186
# But this RV needs it to be constant in order to define the inner graph
2185-
def lkj_corr_tril(n, eta, shape=None):
2186-
tril_idx = pt.tril_indices(n)
2187-
return _LKJCorr.dist(n=n, eta=eta, shape=shape)[..., tril_idx[0], tril_idx[1]]
2188-
2189-
def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs):
2190-
return CustomDist(name, n, eta, dist=lkj_corr_tril, shape=shape)
2191-
21922187
for n in (2, 10, 50):
21932188
continuous_random_tester(
2194-
SlicedLKJ,
2189+
_LKJCorr,
21952190
{
21962191
"eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
21972192
},
@@ -2204,7 +2199,7 @@ def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs):
22042199
@pytest.mark.parametrize("shape", [(2, 2), (3, 2, 2)], ids=["no_batch", "with_batch"])
22052200
def test_LKJCorr_default_transform(shape):
22062201
with pm.Model() as m:
2207-
x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False)
2202+
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
22082203
assert isinstance(m.rvs_to_transforms[x], CholeskyCorrTransform)
22092204
assert m.logp(sum=False)[0].type.shape == shape[:-2]
22102205

0 commit comments

Comments
 (0)