Skip to content

Commit 851c991

Browse files
author
Juan Orduz
authored
Fix default transform for LKJCorr (#7065)
1 parent 35cd657 commit 851c991

File tree

3 files changed

+58
-11
lines changed

3 files changed

+58
-11
lines changed

pymc/distributions/multivariate.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,13 @@ def _random_corr_matrix(cls, rng, n, eta, flat_size):
15241524
lkjcorr = LKJCorrRV()
15251525

15261526

1527+
class MultivariateIntervalTransform(Interval):
1528+
name = "interval"
1529+
1530+
def log_jac_det(self, *args):
1531+
return super().log_jac_det(*args).sum(-1)
1532+
1533+
15271534
class LKJCorr(BoundedContinuous):
15281535
r"""
15291536
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
@@ -1592,6 +1599,9 @@ def logp(value, n, eta):
15921599
TensorVariable
15931600
"""
15941601

1602+
if value.ndim > 1:
1603+
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
1604+
15951605
# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
15961606
# n (or else find a different expression)
15971607
if not isinstance(n, Constant):
@@ -1623,7 +1633,7 @@ def logp(value, n, eta):
16231633

16241634
@_default_transform.register(LKJCorr)
16251635
def lkjcorr_default_transform(op, rv):
1626-
return Interval(floatX(-1.0), floatX(1.0))
1636+
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))
16271637

16281638

16291639
class MatrixNormalRV(RandomVariable):

pymc/logprob/transform_value.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,12 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)
125125
raise NotImplementedError(
126126
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
127127
)
128-
else:
129-
# Check there is no broadcasting between logp and jacobian
130-
if logp.type.broadcastable != log_jac_det.type.broadcastable:
131-
raise ValueError(
132-
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
133-
"There is a bug in the implementation of either one."
134-
)
128+
# Check there is no broadcasting between logp and jacobian
129+
if logp.type.broadcastable != log_jac_det.type.broadcastable:
130+
raise ValueError(
131+
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
132+
"There is a bug in the implementation of either one."
133+
)
135134

136135
if use_jacobian:
137136
if value.name:

tests/distributions/test_multivariate.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import functools as ft
16-
import re
1716
import warnings
1817

1918
import numpy as np
@@ -33,6 +32,7 @@
3332
import pymc as pm
3433

3534
from pymc.distributions.multivariate import (
35+
MultivariateIntervalTransform,
3636
_LKJCholeskyCov,
3737
_OrderedMultinomial,
3838
posdef,
@@ -1306,8 +1306,26 @@ def test_kronecker_normal_moment(self, mu, covs, size, expected):
13061306
[
13071307
(3, 1, None, np.zeros(3)),
13081308
(5, 1, None, np.zeros(10)),
1309-
(3, 1, 1, np.zeros((1, 3))),
1310-
(5, 1, (2, 3), np.zeros((2, 3, 10))),
1309+
pytest.param(
1310+
3,
1311+
1,
1312+
1,
1313+
np.zeros((1, 3)),
1314+
marks=pytest.mark.xfail(
1315+
raises=NotImplementedError,
1316+
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
1317+
),
1318+
),
1319+
pytest.param(
1320+
5,
1321+
1,
1322+
(2, 3),
1323+
np.zeros((2, 3, 10)),
1324+
marks=pytest.mark.xfail(
1325+
raises=NotImplementedError,
1326+
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
1327+
),
1328+
),
13111329
],
13121330
)
13131331
def test_lkjcorr_moment(self, n, eta, size, expected):
@@ -2157,6 +2175,26 @@ def ref_rand(size, n, eta):
21572175
)
21582176

21592177

2178+
@pytest.mark.parametrize(
2179+
argnames="shape",
2180+
argvalues=[
2181+
(2,),
2182+
pytest.param(
2183+
(3, 2),
2184+
marks=pytest.mark.xfail(
2185+
raises=NotImplementedError,
2186+
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
2187+
),
2188+
),
2189+
],
2190+
)
2191+
def test_LKJCorr_default_transform(shape):
2192+
with pm.Model() as m:
2193+
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
2194+
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
2195+
assert m.logp(sum=False)[0].type.shape == shape[:-1]
2196+
2197+
21602198
class TestLKJCholeskyCov(BaseTestDistributionRandom):
21612199
pymc_dist = _LKJCholeskyCov
21622200
pymc_dist_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}

0 commit comments

Comments
 (0)