Skip to content

Commit 4a0f717

Browse files
LKJCorr is always a matrix
1 parent 6f7acf5 commit 4a0f717

File tree

2 files changed

+70
-96
lines changed

2 files changed

+70
-96
lines changed

pymc/distributions/multivariate.py

Lines changed: 40 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
sigmoid,
3535
)
3636
from pytensor.tensor.blockwise import Blockwise
37+
from pytensor.tensor.einsum import _delta
3738
from pytensor.tensor.elemwise import DimShuffle
3839
from pytensor.tensor.exceptions import NotScalarConstantError
3940
from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -76,7 +77,6 @@
7677
)
7778
from pymc.distributions.transforms import (
7879
CholeskyCorrTransform,
79-
Interval,
8080
ZeroSumTransform,
8181
_default_transform,
8282
)
@@ -1157,12 +1157,12 @@ def _lkj_normalizing_constant(eta, n):
11571157
if not isinstance(n, int):
11581158
raise NotImplementedError("n must be an integer")
11591159
if eta == 1:
1160-
result = gammaln(2.0 * pt.arange(1, int((n - 1) / 2) + 1)).sum()
1160+
result = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
11611161
if n % 2 == 1:
11621162
result += (
11631163
0.25 * (n**2 - 1) * pt.log(np.pi)
11641164
- 0.25 * (n - 1) ** 2 * pt.log(2.0)
1165-
- (n - 1) * gammaln(int((n + 1) / 2))
1165+
- (n - 1) * gammaln((n + 1) / 2)
11661166
)
11671167
else:
11681168
result += (
@@ -1504,7 +1504,7 @@ def helper_deterministics(cls, n, packed_chol):
15041504

15051505
class LKJCorrRV(SymbolicRandomVariable):
15061506
name = "lkjcorr"
1507-
extended_signature = "[rng],[size],(),()->[rng],(n)"
1507+
extended_signature = "[rng],[size],(),()->[rng],(n,n)"
15081508
_print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
15091509

15101510
def make_node(self, rng, size, n, eta):
@@ -1532,23 +1532,13 @@ def rv_op(cls, n: int, eta, *, rng=None, size=None):
15321532
flat_size = pt.prod(size, dtype="int64")
15331533

15341534
next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
1535-
1536-
triu_idx = pt.triu_indices(n, k=1)
1537-
samples = C[..., triu_idx[0], triu_idx[1]]
1538-
1539-
if rv_size_is_none(size):
1540-
samples = samples[0]
1541-
else:
1542-
dist_shape = (n * (n - 1)) // 2
1543-
samples = pt.reshape(samples, (*size, dist_shape))
1535+
C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n))
15441536

15451537
return cls(
15461538
inputs=[rng, size, n, eta],
1547-
outputs=[next_rng, samples],
1539+
outputs=[next_rng, C],
15481540
)(rng, size, n, eta)
15491541

1550-
return samples
1551-
15521542
@classmethod
15531543
def _random_corr_matrix(
15541544
cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
@@ -1565,6 +1555,7 @@ def _random_corr_matrix(
15651555
P = P[..., 0, 1].set(r12)
15661556
P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
15671557
n = get_underlying_scalar_constant_value(n)
1558+
15681559
for mp1 in range(2, n):
15691560
beta -= 0.5
15701561
next_rng, y = pt.random.beta(
@@ -1577,17 +1568,10 @@ def _random_corr_matrix(
15771568
P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
15781569
P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
15791570
C = pt.einsum("...ji,...jk->...ik", P, P.copy())
1580-
return next_rng, C
1581-
15821571

1583-
class MultivariateIntervalTransform(Interval):
1584-
name = "interval"
1585-
1586-
def log_jac_det(self, *args):
1587-
return super().log_jac_det(*args).sum(-1)
1572+
return next_rng, C
15881573

15891574

1590-
# Returns list of upper triangular values
15911575
class _LKJCorr(BoundedContinuous):
15921576
rv_type = LKJCorrRV
15931577
rv_op = LKJCorrRV.rv_op
@@ -1598,10 +1582,15 @@ def dist(cls, n, eta, **kwargs):
15981582
eta = pt.as_tensor_variable(eta)
15991583
return super().dist([n, eta], **kwargs)
16001584

1601-
def support_point(rv, *args):
1602-
return pt.zeros_like(rv)
1585+
@staticmethod
1586+
def support_point(rv: TensorVariable, *args):
1587+
ndim = rv.ndim
16031588

1604-
def logp(value, n, eta):
1589+
# Batched identity matrix
1590+
return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int)
1591+
1592+
@staticmethod
1593+
def logp(value: TensorVariable, n, eta):
16051594
"""
16061595
Calculate logp of LKJ distribution at specified value.
16071596
@@ -1614,31 +1603,20 @@ def logp(value, n, eta):
16141603
-------
16151604
TensorVariable
16161605
"""
1617-
if value.ndim > 1:
1618-
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
1619-
1620-
# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
1621-
# n (or else find a different expression)
1606+
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16221607
try:
16231608
n = int(get_underlying_scalar_constant_value(n))
16241609
except NotScalarConstantError:
16251610
raise NotImplementedError("logp only implemented for constant `n`")
16261611

1627-
shape = n * (n - 1) // 2
1628-
tri_index = np.zeros((n, n), dtype="int32")
1629-
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1630-
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1631-
1632-
value = pt.take(value, tri_index)
1633-
value = pt.fill_diagonal(value, 1)
1634-
1635-
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16361612
try:
16371613
eta = float(get_underlying_scalar_constant_value(eta))
16381614
except NotScalarConstantError:
16391615
raise NotImplementedError("logp only implemented for constant `eta`")
1616+
16401617
result = _lkj_normalizing_constant(eta, n)
16411618
result += (eta - 1.0) * pt.log(det(value))
1619+
16421620
return check_parameters(
16431621
result,
16441622
value >= -1,
@@ -1675,10 +1653,6 @@ class LKJCorr:
16751653
The shape parameter (eta > 0) of the LKJ distribution. eta = 1
16761654
implies a uniform distribution of the correlation matrices;
16771655
larger values put more weight on matrices with few correlations.
1678-
return_matrix : bool, default=False
1679-
If True, returns the full correlation matrix.
1680-
False only returns the values of the upper triangular matrix excluding
1681-
diagonal in a single vector of length n(n-1)/2 for memory efficiency
16821656
16831657
Notes
16841658
-----
@@ -1693,7 +1667,7 @@ class LKJCorr:
16931667
# Define the vector of fixed standard deviations
16941668
sds = 3 * np.ones(10)
16951669
1696-
corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True)
1670+
corr = pm.LKJCorr("corr", eta=4, n=10)
16971671
16981672
# Define a new MvNormal with the given correlation matrix
16991673
vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10)
@@ -1703,10 +1677,6 @@ class LKJCorr:
17031677
chol = pt.linalg.cholesky(corr)
17041678
vals = sds * pt.dot(chol, vals_raw)
17051679
1706-
# The matrix is internally still sampled as an upper triangular vector
1707-
# If you want access to it in matrix form in the trace, add
1708-
pm.Deterministic("corr_mat", corr)
1709-
17101680
17111681
References
17121682
----------
@@ -1716,26 +1686,28 @@ class LKJCorr:
17161686
100(9), pp.1989-2001.
17171687
"""
17181688

1719-
def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
1720-
c_vec = _LKJCorr(name, eta=eta, n=n, **kwargs)
1721-
if not return_matrix:
1722-
return c_vec
1723-
else:
1724-
return cls.vec_to_corr_mat(c_vec, n)
1725-
1726-
@classmethod
1727-
def dist(cls, n, eta, *, return_matrix=False, **kwargs):
1728-
c_vec = _LKJCorr.dist(eta=eta, n=n, **kwargs)
1729-
if not return_matrix:
1730-
return c_vec
1731-
else:
1732-
return cls.vec_to_corr_mat(c_vec, n)
1689+
def __new__(cls, name, n, eta, **kwargs):
1690+
return_matrix = kwargs.pop("return_matrix", None)
1691+
if return_matrix is not None:
1692+
warnings.warn(
1693+
"The `return_matrix` argument is deprecated and has no effect. "
1694+
"LKJCorr always returns the correlation matrix.",
1695+
DeprecationWarning,
1696+
stacklevel=2,
1697+
)
1698+
return _LKJCorr(name, eta=eta, n=n, **kwargs)
17331699

17341700
@classmethod
1735-
def vec_to_corr_mat(cls, vec, n):
1736-
tri = pt.zeros(pt.concatenate([vec.shape[:-1], (n, n)]))
1737-
tri = pt.subtensor.set_subtensor(tri[(..., *np.triu_indices(n, 1))], vec)
1738-
return tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n))
1701+
def dist(cls, n, eta, **kwargs):
1702+
return_matrix = kwargs.pop("return_matrix", None)
1703+
if return_matrix is not None:
1704+
warnings.warn(
1705+
"The `return_matrix` argument is deprecated and has no effect. "
1706+
"LKJCorr always returns the correlation matrix.",
1707+
DeprecationWarning,
1708+
stacklevel=2,
1709+
)
1710+
return _LKJCorr.dist(eta=eta, n=n, **kwargs)
17391711

17401712

17411713
class MatrixNormalRV(RandomVariable):

tests/distributions/test_multivariate.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@
3535

3636
from pymc import Model
3737
from pymc.distributions.multivariate import (
38-
MultivariateIntervalTransform,
3938
_LKJCholeskyCov,
4039
_LKJCorr,
4140
_OrderedMultinomial,
4241
posdef,
4342
quaddist_matrix,
4443
)
4544
from pymc.distributions.shape_utils import change_dist_size, to_tuple
45+
from pymc.distributions.transforms import CholeskyCorrTransform
4646
from pymc.logprob.basic import logp
4747
from pymc.logprob.utils import ParameterValueError
4848
from pymc.math import kronecker
@@ -559,10 +559,14 @@ def test_wishart(self, n):
559559
lambda value, nu, V: st.wishart.logpdf(value, int(nu), V),
560560
)
561561

562-
@pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES)
563-
def test_lkjcorr(self, x, eta, n, lp):
562+
@pytest.mark.parametrize("x_tri,eta,n,lp", LKJ_CASES)
563+
def test_lkjcorr(self, x_tri, eta, n, lp):
564564
with pm.Model() as model:
565-
pm.LKJCorr("lkj", eta=eta, n=n, default_transform=None, return_matrix=False)
565+
pm.LKJCorr("lkj", eta=eta, n=n, transform=None)
566+
567+
x = np.eye(n)
568+
x[np.tril_indices(n, -1)] = x_tri
569+
x[np.triu_indices(n, 1)] = x_tri
566570

567571
point = {"lkj": x}
568572
decimals = select_by_precision(float64=6, float32=4)
@@ -2153,13 +2157,13 @@ class TestLKJCorr(BaseTestDistributionRandom):
21532157

21542158
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
21552159
sizes_expected = [
2156-
(3,),
2157-
(3,),
2158-
(1, 3),
2159-
(1, 3),
2160-
(5, 3),
2161-
(4, 5, 3),
2162-
(2, 4, 2, 3),
2160+
(3, 3),
2161+
(3, 3),
2162+
(1, 3, 3),
2163+
(1, 3, 3),
2164+
(5, 3, 3),
2165+
(4, 5, 3, 3),
2166+
(2, 4, 2, 3, 3),
21632167
]
21642168

21652169
checks_to_run = [
@@ -2169,16 +2173,26 @@ class TestLKJCorr(BaseTestDistributionRandom):
21692173
]
21702174

21712175
def check_draws_match_expected(self):
2176+
from pymc.distributions import CustomDist
2177+
21722178
def ref_rand(size, n, eta):
21732179
shape = int(n * (n - 1) // 2)
21742180
beta = eta - 1 + n / 2
2175-
return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
2181+
tril_values = (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2
2182+
return tril_values
21762183

21772184
# If passed as a domain, continuous_random_tester would make `n` a shared variable
21782185
# But this RV needs it to be constant in order to define the inner graph
2186+
def lkj_corr_tril(n, eta, shape=None):
2187+
tril_idx = pt.tril_indices(n)
2188+
return _LKJCorr.dist(n=n, eta=eta, shape=shape)[..., tril_idx[0], tril_idx[1]]
2189+
2190+
def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs):
2191+
return CustomDist(name, n, eta, dist=lkj_corr_tril, shape=shape)
2192+
21792193
for n in (2, 10, 50):
21802194
continuous_random_tester(
2181-
_LKJCorr,
2195+
SlicedLKJ,
21822196
{
21832197
"eta": Domain([1.0, 10.0, 100.0], edges=(None, None)),
21842198
},
@@ -2188,24 +2202,12 @@ def ref_rand(size, n, eta):
21882202
)
21892203

21902204

2191-
@pytest.mark.parametrize(
2192-
argnames="shape",
2193-
argvalues=[
2194-
(2,),
2195-
pytest.param(
2196-
(3, 2),
2197-
marks=pytest.mark.xfail(
2198-
raises=NotImplementedError,
2199-
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
2200-
),
2201-
),
2202-
],
2203-
)
2205+
@pytest.mark.parametrize("shape", [(2, 2), (3, 2, 2)], ids=["no_batch", "with_batch"])
22042206
def test_LKJCorr_default_transform(shape):
22052207
with pm.Model() as m:
22062208
x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False)
2207-
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
2208-
assert m.logp(sum=False)[0].type.shape == shape[:-1]
2209+
assert isinstance(m.rvs_to_transforms[x], CholeskyCorrTransform)
2210+
assert m.logp(sum=False)[0].type.shape == shape[:-2]
22092211

22102212

22112213
class TestLKJCholeskyCov(BaseTestDistributionRandom):

0 commit comments

Comments
 (0)