Skip to content

Commit 0d75490

Browse files
committed
Fix TestMatrixNormal.check_draws
Test behavior was accidentally changed in 9dad9c2
1 parent 9956991 commit 0d75490

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

tests/distributions/test_multivariate.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,26 +1907,22 @@ def check_draws(self):
19071907
def ref_rand(mu, rowcov, colcov):
19081908
return st.matrix_normal.rvs(mean=mu, rowcov=rowcov, colcov=colcov)
19091909

1910-
with pm.Model():
1911-
matrixnormal = pm.MatrixNormal(
1912-
"matnormal",
1913-
mu=np.random.random((3, 3)),
1914-
rowcov=np.eye(3),
1915-
colcov=np.eye(3),
1916-
)
1917-
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=1)
1918-
1919-
ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
1910+
matrixnormal = pm.MatrixNormal.dist(
1911+
mu=np.random.random((3, 3)),
1912+
rowcov=np.eye(3),
1913+
colcov=np.eye(3),
1914+
)
19201915

19211916
p, f = delta, n_fails
19221917
while p <= delta and f > 0:
1923-
matrixnormal_smp = check["matnormal"]
1918+
matrixnormal_smp = pm.draw(matrixnormal)
1919+
ref_smp = ref_rand(mu=np.random.random((3, 3)), rowcov=np.eye(3), colcov=np.eye(3))
19241920

19251921
p = np.min(
19261922
[
19271923
st.ks_2samp(
1928-
np.atleast_1d(matrixnormal_smp).flatten(),
1929-
np.atleast_1d(ref_smp).flatten(),
1924+
matrixnormal_smp.flatten(),
1925+
ref_smp.flatten(),
19301926
)
19311927
]
19321928
)

0 commit comments

Comments
 (0)