164 changes: 66 additions & 98 deletions pymc/distributions/multivariate.py
@@ -33,6 +33,7 @@
     get_underlying_scalar_constant_value,
     sigmoid,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -73,7 +74,11 @@
     rv_size_is_none,
     to_tuple,
 )
-from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
+from pymc.distributions.transforms import (
+    CholeskyCorrTransform,
+    ZeroSumTransform,
+    _default_transform,
+)
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.rewriting import (
     specialization_ir_rewrites_db,
@@ -918,18 +923,15 @@ def posdef(AA):
 class PosDefMatrix(Op):
     """Check if input is positive definite. Input should be a square matrix."""
 
-    # Properties attribute
     __props__ = ()
 
-    # Compulsory if itypes and otypes are not defined
+    gufunc_signature = "(m,m)->()"
+
     def make_node(self, x):
         x = pt.as_tensor_variable(x)
         assert x.ndim == 2
         o = TensorType(dtype="bool", shape=[])()
         return Apply(self, [x], [o])
 
-    # Python implementation:
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
@@ -950,7 +952,7 @@ def __str__(self):
         return "MatrixIsPositiveDefinite"
 
 
-matrix_pos_def = PosDefMatrix()
+matrix_pos_def = Blockwise(PosDefMatrix())
 
 
 class WishartRV(RandomVariable):
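The hunk above gives PosDefMatrix a core gufunc signature and wraps the module-level instance in Blockwise, so the check vectorizes over leading batch dimensions. A minimal sketch of the expected semantics (editor's sketch, not part of the diff; it assumes only the public names shown above):

```python
import numpy as np
import pytensor.tensor as pt
from pymc.distributions.multivariate import matrix_pos_def

# With the core signature "(m,m)->()", Blockwise maps the scalar-core
# check over any leading batch dimensions: (batch, m, m) -> (batch,)
x = pt.tensor3("x")
batch = np.stack([np.eye(2), -np.eye(2)])
print(matrix_pos_def(x).eval({x: batch}))  # [ True False]
```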
@@ -1154,12 +1156,12 @@ def _lkj_normalizing_constant(eta, n):
     if not isinstance(n, int):
         raise NotImplementedError("n must be an integer")
     if eta == 1:
-        result = gammaln(2.0 * pt.arange(1, int((n - 1) / 2) + 1)).sum()
+        result = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum()
         if n % 2 == 1:
             result += (
                 0.25 * (n**2 - 1) * pt.log(np.pi)
                 - 0.25 * (n - 1) ** 2 * pt.log(2.0)
-                - (n - 1) * gammaln(int((n + 1) / 2))
+                - (n - 1) * gammaln((n + 1) / 2)
             )
         else:
             result += (
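Dropping the `int()` casts is more than a cleanup: for even n, `(n - 1) / 2` is fractional and the old cast truncated the `arange` bound before the gammaln terms were summed. A small check of that boundary (editor's sketch, not part of the diff):

```python
import numpy as np

n = 4  # even n: (n - 1) / 2 == 1.5
print(np.arange(1, int((n - 1) / 2) + 1))  # [1]      -> old, truncated bound
print(np.arange(1, ((n - 1) / 2) + 1))     # [1. 2.]  -> new, fractional bound
```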
@@ -1210,12 +1212,8 @@ def rv_op(cls, n, eta, sd_dist, *, size=None):
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]
 
-        # We flatten the size to make operations easier, and then rebuild it
-        flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D_matrix = D.reshape((flat_size, n))
-        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        C *= D[..., :, None] * D[..., None, :]
 
         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
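With size threaded through directly, the standard deviations D scale the correlation matrix by broadcasting instead of an explicit reshape to a flat batch. The scaling line is the usual correlation-to-covariance identity; in NumPy terms (editor's sketch, not part of the diff):

```python
import numpy as np

C = np.array([[1.0, 0.5], [0.5, 1.0]])  # correlation matrix
D = np.array([2.0, 3.0])                # standard deviations

# Same broadcasting as `C *= D[..., :, None] * D[..., None, :]` above;
# leading batch dimensions ride along through the ellipsis.
cov = C * D[:, None] * D[None, :]
assert np.allclose(cov, np.diag(D) @ C @ np.diag(D))
```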
@@ -1501,7 +1499,7 @@ def helper_deterministics(cls, n, packed_chol):
 
 class LKJCorrRV(SymbolicRandomVariable):
     name = "lkjcorr"
-    extended_signature = "[rng],[size],(),()->[rng],(n)"
+    extended_signature = "[rng],[size],(),()->[rng],(n,n)"
     _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
 
     def make_node(self, rng, size, n, eta):
@@ -1517,74 +1515,57 @@ def make_node(self, rng, size, n, eta):
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
+        # HACK: normalize_size_param doesn't handle size=() properly
+        if not size:
+            size = None
+
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-
-        triu_idx = pt.triu_indices(n, k=1)
-        samples = C[..., triu_idx[0], triu_idx[1]]
-
-        if rv_size_is_none(size):
-            samples = samples[0]
-        else:
-            dist_shape = (n * (n - 1)) // 2
-            samples = pt.reshape(samples, (*size, dist_shape))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, samples],
-        )(rng, size, n, eta)
-
-        return samples
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        size = () if rv_size_is_none(size) else size
+
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
 
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i"
+            z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
-        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
-        return next_rng, C
-
-
-class MultivariateIntervalTransform(Interval):
-    name = "interval"
-
-    def log_jac_det(self, *args):
-        return super().log_jac_det(*args).sum(-1)
+
+        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
+
+        return next_rng, C
 
 
-# Returns list of upper triangular values
 class _LKJCorr(BoundedContinuous):
     rv_type = LKJCorrRV
     rv_op = LKJCorrRV.rv_op
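Together with the `(n)` to `(n,n)` signature change above, LKJCorr now draws full correlation matrices, and the user-supplied size is handled as ordinary batch dimensions instead of being flattened and rebuilt. Expected draw shapes, assuming this branch (editor's sketch):

```python
import pymc as pm

corr = pm.LKJCorr.dist(n=3, eta=2.0, size=(4,))
print(pm.draw(corr).shape)  # (4, 3, 3): one 3x3 correlation matrix per batch element
```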
@@ -1595,10 +1576,12 @@ def dist(cls, n, eta, **kwargs):
         eta = pt.as_tensor_variable(eta)
         return super().dist([n, eta], **kwargs)
 
-    def support_point(rv, *args):
-        return pt.zeros_like(rv)
+    @staticmethod
+    def support_point(rv: TensorVariable, *args):
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
-    def logp(value, n, eta):
+    @staticmethod
+    def logp(value: TensorVariable, n, eta):
         """
         Calculate logp of LKJ distribution at specified value.
 
@@ -1611,31 +1594,20 @@ def logp(value, n, eta):
         -------
         TensorVariable
         """
-        if value.ndim > 1:
-            raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
-
-        # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
-        #  n (or else find a different expression)
+        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
         try:
             n = int(get_underlying_scalar_constant_value(n))
         except NotScalarConstantError:
             raise NotImplementedError("logp only implemented for constant `n`")
 
-        shape = n * (n - 1) // 2
-        tri_index = np.zeros((n, n), dtype="int32")
-        tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
-        tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
-
-        value = pt.take(value, tri_index)
-        value = pt.fill_diagonal(value, 1)
-
-        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
         try:
             eta = float(get_underlying_scalar_constant_value(eta))
         except NotScalarConstantError:
             raise NotImplementedError("logp only implemented for constant `eta`")
 
         result = _lkj_normalizing_constant(eta, n)
         result += (eta - 1.0) * pt.log(det(value))
 
         return check_parameters(
             result,
             value >= -1,
@@ -1647,7 +1619,9 @@ def logp(value, n, eta):
 
 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    return MultivariateIntervalTransform(-1.0, 1.0)
+    rng, shape, n, eta, *_ = rv.owner.inputs
+    n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
+    return CholeskyCorrTransform(n=n, upper=False)
 
 
 class LKJCorr:
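The default-transform change is substantive: the removed MultivariateIntervalTransform only constrained each off-diagonal entry to (-1, 1), which does not by itself keep the matrix positive definite, whereas a Cholesky-based transform parameterizes valid correlation matrices directly. A counterexample showing why elementwise bounds are insufficient (editor's sketch, not part of the diff):

```python
import numpy as np

# Every off-diagonal entry is in (-1, 1), yet the matrix is not
# positive definite, so it is not a valid correlation matrix:
C = np.array([[1.0, 0.9, -0.9],
              [0.9, 1.0, 0.9],
              [-0.9, 0.9, 1.0]])
print(np.linalg.eigvalsh(C))  # contains a negative eigenvalue
```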
@@ -1670,10 +1644,6 @@ class LKJCorr:
         The shape parameter (eta > 0) of the LKJ distribution. eta = 1
         implies a uniform distribution of the correlation matrices;
         larger values put more weight on matrices with few correlations.
-    return_matrix : bool, default=False
-        If True, returns the full correlation matrix.
-        False only returns the values of the upper triangular matrix excluding
-        diagonal in a single vector of length n(n-1)/2 for memory efficiency
 
     Notes
     -----
@@ -1688,7 +1658,7 @@ class LKJCorr:
             # Define the vector of fixed standard deviations
             sds = 3 * np.ones(10)
 
-            corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True)
+            corr = pm.LKJCorr("corr", eta=4, n=10)
 
             # Define a new MvNormal with the given correlation matrix
             vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10)
@@ -1698,10 +1668,6 @@ class LKJCorr:
             chol = pt.linalg.cholesky(corr)
             vals = sds * pt.dot(chol, vals_raw)
 
-            # The matrix is internally still sampled as a upper triangular vector
-            # If you want access to it in matrix form in the trace, add
-            pm.Deterministic("corr_mat", corr)
-
 
     References
     ----------
@@ -1711,26 +1677,28 @@ class LKJCorr:
     100(9), pp.1989-2001.
     """
 
-    def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
-        c_vec = _LKJCorr(name, eta=eta, n=n, **kwargs)
-        if not return_matrix:
-            return c_vec
-        else:
-            return cls.vec_to_corr_mat(c_vec, n)
-
-    @classmethod
-    def dist(cls, n, eta, *, return_matrix=False, **kwargs):
-        c_vec = _LKJCorr.dist(eta=eta, n=n, **kwargs)
-        if not return_matrix:
-            return c_vec
-        else:
-            return cls.vec_to_corr_mat(c_vec, n)
+    def __new__(cls, name, n, eta, **kwargs):
+        return_matrix = kwargs.pop("return_matrix", None)
+        if return_matrix is not None:
+            warnings.warn(
+                "The `return_matrix` argument is deprecated and has no effect. "
+                "LKJCorr always returns the correlation matrix.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return _LKJCorr(name, eta=eta, n=n, **kwargs)
 
     @classmethod
-    def vec_to_corr_mat(cls, vec, n):
-        tri = pt.zeros(pt.concatenate([vec.shape[:-1], (n, n)]))
-        tri = pt.subtensor.set_subtensor(tri[(..., *np.triu_indices(n, 1))], vec)
-        return tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n))
+    def dist(cls, n, eta, **kwargs):
+        return_matrix = kwargs.pop("return_matrix", None)
+        if return_matrix is not None:
+            warnings.warn(
+                "The `return_matrix` argument is deprecated and has no effect. "
+                "LKJCorr always returns the correlation matrix.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return _LKJCorr.dist(eta=eta, n=n, **kwargs)
 
 
 class MatrixNormalRV(RandomVariable):
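The deprecation shim above keeps old call sites working: `return_matrix` is popped from kwargs, a DeprecationWarning is raised, and the full matrix comes back either way. Expected behavior, assuming this branch (editor's sketch):

```python
import warnings

import pymc as pm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    corr = pm.LKJCorr.dist(n=3, eta=2.0, return_matrix=True)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
print(pm.draw(corr).shape)  # (3, 3), same as without the flag
```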