@@ -1471,11 +1471,15 @@ def test_with_lkjcorr_matrix(
1471
1471
prior = pm .sample_prior_predictive (draws = 10 , return_inferencedata = False )
1472
1472
1473
1473
assert prior ["corr_mat" ].shape == (10 , 3 , 3 ) # square
1474
- assert np .allclose (prior ["corr_mat" ][:, [0 , 1 , 2 ], [0 , 1 , 2 ]], 1.0 ) # 1.0 on diagonal
1475
1474
assert (prior ["corr_mat" ] == prior ["corr_mat" ].transpose (0 , 2 , 1 )).all () # symmetric
1476
- assert (
1477
- prior ["corr_mat" ].max () <= 1.0 and prior ["corr_mat" ].min () >= - 1.0
1478
- ) # constrained between -1 and 1
1475
+
1476
+ np .testing .assert_allclose (
1477
+ prior ["corr_mat" ][:, [0 , 1 , 2 ], [0 , 1 , 2 ]], 1.0
1478
+ ) # 1.0 on diagonal
1479
+
1480
+ # constrained between -1 and 1
1481
+ assert prior ["corr_mat" ].max () <= (1.0 + 1e-12 )
1482
+ assert prior ["corr_mat" ].min () >= (- 1.0 - 1e-12 )
1479
1483
1480
1484
def test_issue_3758 (self ):
1481
1485
np .random .seed (42 )
@@ -2172,8 +2176,6 @@ class TestLKJCorr(BaseTestDistributionRandom):
2172
2176
]
2173
2177
2174
2178
def check_draws_match_expected (self ):
2175
- from pymc .distributions import CustomDist
2176
-
2177
2179
def ref_rand (size , n , eta ):
2178
2180
shape = int (n * (n - 1 ) // 2 )
2179
2181
beta = eta - 1 + n / 2
@@ -2182,16 +2184,9 @@ def ref_rand(size, n, eta):
2182
2184
2183
2185
# If passed as a domain, continuous_random_tester would make `n` a shared variable
2184
2186
# But this RV needs it to be constant in order to define the inner graph
2185
- def lkj_corr_tril (n , eta , shape = None ):
2186
- tril_idx = pt .tril_indices (n )
2187
- return _LKJCorr .dist (n = n , eta = eta , shape = shape )[..., tril_idx [0 ], tril_idx [1 ]]
2188
-
2189
- def SlicedLKJ (name , n , eta , * args , shape = None , ** kwargs ):
2190
- return CustomDist (name , n , eta , dist = lkj_corr_tril , shape = shape )
2191
-
2192
2187
for n in (2 , 10 , 50 ):
2193
2188
continuous_random_tester (
2194
- SlicedLKJ ,
2189
+ _LKJCorr ,
2195
2190
{
2196
2191
"eta" : Domain ([1.0 , 10.0 , 100.0 ], edges = (None , None )),
2197
2192
},
@@ -2204,7 +2199,7 @@ def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs):
2204
2199
@pytest .mark .parametrize ("shape" , [(2 , 2 ), (3 , 2 , 2 )], ids = ["no_batch" , "with_batch" ])
2205
2200
def test_LKJCorr_default_transform (shape ):
2206
2201
with pm .Model () as m :
2207
- x = pm .LKJCorr ("x" , n = 2 , eta = 1 , shape = shape , return_matrix = False )
2202
+ x = pm .LKJCorr ("x" , n = 2 , eta = 1 , shape = shape )
2208
2203
assert isinstance (m .rvs_to_transforms [x ], CholeskyCorrTransform )
2209
2204
assert m .logp (sum = False )[0 ].type .shape == shape [:- 2 ]
2210
2205
0 commit comments