diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index f76a98546e..39278fcc57 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -33,6 +33,7 @@
     get_underlying_scalar_constant_value,
     sigmoid,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
@@ -73,7 +74,11 @@
     rv_size_is_none,
     to_tuple,
 )
-from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
+from pymc.distributions.transforms import (
+    CholeskyCorrTransform,
+    ZeroSumTransform,
+    _default_transform,
+)
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.rewriting import (
     specialization_ir_rewrites_db,
@@ -918,10 +923,8 @@ def posdef(AA):
 class PosDefMatrix(Op):
     """Check if input is positive definite. Input should be a square matrix."""
 
-    # Properties attribute
     __props__ = ()
-
-    # Compulsory if itypes and otypes are not defined
+    gufunc_signature = "(m,m)->()"
 
     def make_node(self, x):
         x = pt.as_tensor_variable(x)
@@ -929,7 +932,6 @@
         o = TensorType(dtype="bool", shape=[])()
         return Apply(self, [x], [o])
 
-    # Python implementation:
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
@@ -950,7 +952,7 @@
         return "MatrixIsPositiveDefinite"
 
 
-matrix_pos_def = PosDefMatrix()
+matrix_pos_def = Blockwise(PosDefMatrix())
 
 
 class WishartRV(RandomVariable):
@@ -1154,12 +1156,12 @@
     if not isinstance(n, int):
         raise NotImplementedError("n must be an integer")
     if eta == 1:
-        result = gammaln(2.0 * pt.arange(1, int((n - 1) / 2) + 1)).sum()
+        result = gammaln(2.0 * pt.arange(1, (n - 1) // 2 + 1)).sum()
         if n % 2 == 1:
             result += (
                 0.25 * (n**2 - 1) * pt.log(np.pi)
                 - 0.25 * (n - 1) ** 2 * pt.log(2.0)
-                - (n - 1) * gammaln(int((n + 1) / 2))
+                - (n - 1) * gammaln((n + 1) // 2)
             )
         else:
             result += (
@@ -1210,12 +1212,8 @@
         D = sd_dist.type(name="D")  # Make sd_dist opaque to OpFromGraph
         size = D.shape[:-1]
 
-        # We flatten the size to make operations easier, and then rebuild it
-        flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-        D_matrix = D.reshape((flat_size, n))
-        C *= D_matrix[..., :, None] * D_matrix[..., None, :]
+        next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
+        C *= D[..., :, None] * D[..., None, :]
 
         tril_idx = pt.tril_indices(n, k=0)
         samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]]
@@ -1501,7 +1499,7 @@ def helper_deterministics(cls, n, packed_chol):
 class LKJCorrRV(SymbolicRandomVariable):
     name = "lkjcorr"
-    extended_signature = "[rng],[size],(),()->[rng],(n)"
+    extended_signature = "[rng],[size],(),()->[rng],(n,n)"
     _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
 
     def make_node(self, rng, size, n, eta):
@@ -1517,74 +1515,57 @@
 
     @classmethod
     def rv_op(cls, n: int, eta, *, rng=None, size=None):
-        # We flatten the size to make operations easier, and then rebuild it
+        # HACK: normalize_size_param doesn't handle size=() properly
+        if not size:
+            size = None
+
         n = pt.as_tensor(n, ndim=0, dtype=int)
         eta = pt.as_tensor(eta, ndim=0)
         rng = normalize_rng_param(rng)
         size = normalize_size_param(size)
 
-        if rv_size_is_none(size):
-            flat_size = 1
-        else:
-            flat_size = pt.prod(size, dtype="int64")
-
-        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size)
-
-        triu_idx = pt.triu_indices(n, k=1)
-        samples = C[..., triu_idx[0], triu_idx[1]]
+        next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size)
 
-        if rv_size_is_none(size):
-            samples = samples[0]
-        else:
-            dist_shape = (n * (n - 1)) // 2
-            samples = pt.reshape(samples, (*size, dist_shape))
-
-        return cls(
-            inputs=[rng, size, n, eta],
-            outputs=[next_rng, samples],
-        )(rng, size, n, eta)
-
-        return samples
+        return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta)
 
     @classmethod
     def _random_corr_matrix(
-        cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable
+        cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable
     ) -> tuple[Variable, TensorVariable]:
         # original implementation in R see:
         # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
+        size = () if rv_size_is_none(size) else size
         beta = eta - 1.0 + n / 2.0
-        next_rng, beta_rvs = pt.random.beta(
-            alpha=beta, beta=beta, size=flat_size, rng=rng
-        ).owner.outputs
+        next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs
         r12 = 2.0 * beta_rvs - 1.0
-        P = pt.full((flat_size, n, n), pt.eye(n))
+
+        P = pt.full((*size, n, n), pt.eye(n))
         P = P[..., 0, 1].set(r12)
         P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2))
         n = get_underlying_scalar_constant_value(n)
+
         for mp1 in range(2, n):
             beta -= 0.5
+
             next_rng, y = pt.random.beta(
-                alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng
+                alpha=mp1 / 2.0, beta=beta, size=size, rng=next_rng
             ).owner.outputs
+
             next_rng, z = pt.random.normal(
-                loc=0, scale=1, size=(flat_size, mp1), rng=next_rng
+                loc=0, scale=1, size=(*size, mp1), rng=next_rng
             ).owner.outputs
-            z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis]
+
+            # Normalize each z to unit L2 norm along its last axis; the ellipsis
+            # broadcasts over any leading batch dimensions (and handles ndim == 1).
+            z = z / pt.sqrt(pt.einsum("...i,...i->...", z, z.copy()))[..., np.newaxis]
             P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z)
             P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y))
-        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
-        return next_rng, C
 
-class MultivariateIntervalTransform(Interval):
-    name = "interval"
+        C = pt.einsum("...ji,...jk->...ik", P, P.copy())
 
-    def log_jac_det(self, *args):
-        return super().log_jac_det(*args).sum(-1)
+        return next_rng, C
 
 
-# Returns list of upper triangular values
 class _LKJCorr(BoundedContinuous):
     rv_type = LKJCorrRV
     rv_op = LKJCorrRV.rv_op
@@ -1595,10 +1576,12 @@ def dist(cls, n, eta, **kwargs):
         eta = pt.as_tensor_variable(eta)
         return super().dist([n, eta], **kwargs)
 
-    def support_point(rv, *args):
-        return pt.zeros_like(rv)
+    @staticmethod
+    def support_point(rv: TensorVariable, *args):
+        return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape)
 
-    def logp(value, n, eta):
+    @staticmethod
+    def logp(value: TensorVariable, n, eta):
         """
         Calculate logp of LKJ distribution at specified value.
@@ -1611,31 +1594,20 @@ def logp(value, n, eta):
         -------
         TensorVariable
         """
-        if value.ndim > 1:
-            raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
-
-        # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
-        # n (or else find a different expression)
+        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
         try:
             n = int(get_underlying_scalar_constant_value(n))
         except NotScalarConstantError:
             raise NotImplementedError("logp only implemented for constant `n`")
 
-        shape = n * (n - 1) // 2
-        tri_index = np.zeros((n, n), dtype="int32")
-        tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
-        tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
-
-        value = pt.take(value, tri_index)
-        value = pt.fill_diagonal(value, 1)
-
-        # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
         try:
             eta = float(get_underlying_scalar_constant_value(eta))
         except NotScalarConstantError:
             raise NotImplementedError("logp only implemented for constant `eta`")
+
         result = _lkj_normalizing_constant(eta, n)
         result += (eta - 1.0) * pt.log(det(value))
+
         return check_parameters(
             result,
             value >= -1,
@@ -1647,7 +1619,9 @@
 
 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    return MultivariateIntervalTransform(-1.0, 1.0)
+    rng, shape, n, eta, *_ = rv.owner.inputs
+    n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
+    return CholeskyCorrTransform(n=n, upper=False)
 
 
 class LKJCorr:
@@ -1670,10 +1644,6 @@ class LKJCorr:
         The shape parameter (eta > 0) of the LKJ distribution. eta = 1
         implies a uniform distribution of the correlation matrices;
         larger values put more weight on matrices with few correlations.
-    return_matrix : bool, default=False
-        If True, returns the full correlation matrix.
-        False only returns the values of the upper triangular matrix excluding
-        diagonal in a single vector of length n(n-1)/2 for memory efficiency
 
     Notes
     -----
@@ -1688,7 +1658,7 @@ class LKJCorr:
         # Define the vector of fixed standard deviations
         sds = 3 * np.ones(10)
 
-        corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True)
+        corr = pm.LKJCorr("corr", eta=4, n=10)
 
         # Define a new MvNormal with the given correlation matrix
         vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10)
@@ -1698,10 +1668,6 @@ class LKJCorr:
         chol = pt.linalg.cholesky(corr)
         vals = sds * pt.dot(chol, vals_raw)
 
-        # The matrix is internally still sampled as a upper triangular vector
-        # If you want access to it in matrix form in the trace, add
-        pm.Deterministic("corr_mat", corr)
-
     References
     ----------
@@ -1711,26 +1677,28 @@
        100(9), pp.1989-2001.
     """
 
-    def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs):
-        c_vec = _LKJCorr(name, eta=eta, n=n, **kwargs)
-        if not return_matrix:
-            return c_vec
-        else:
-            return cls.vec_to_corr_mat(c_vec, n)
-
-    @classmethod
-    def dist(cls, n, eta, *, return_matrix=False, **kwargs):
-        c_vec = _LKJCorr.dist(eta=eta, n=n, **kwargs)
-        if not return_matrix:
-            return c_vec
-        else:
-            return cls.vec_to_corr_mat(c_vec, n)
+    def __new__(cls, name, n, eta, **kwargs):
+        return_matrix = kwargs.pop("return_matrix", None)
+        if return_matrix is not None:
+            warnings.warn(
+                "The `return_matrix` argument is deprecated and has no effect. "
" + "LKJCorr always returns the correlation matrix.", + DeprecationWarning, + stacklevel=2, + ) + return _LKJCorr(name, eta=eta, n=n, **kwargs) @classmethod - def vec_to_corr_mat(cls, vec, n): - tri = pt.zeros(pt.concatenate([vec.shape[:-1], (n, n)])) - tri = pt.subtensor.set_subtensor(tri[(..., *np.triu_indices(n, 1))], vec) - return tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n)) + def dist(cls, n, eta, **kwargs): + return_matrix = kwargs.pop("return_matrix", None) + if return_matrix is not None: + warnings.warn( + "The `return_matrix` argument is deprecated and has no effect. " + "LKJCorr always returns the correlation matrix.", + DeprecationWarning, + stacklevel=2, + ) + return _LKJCorr.dist(eta=eta, n=n, **kwargs) class MatrixNormalRV(RandomVariable): diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index b1e02fd1f9..17d4bd6055 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -19,7 +19,7 @@ from pytensor.graph import Op from pytensor.npy_2_compat import normalize_axis_tuple -from pytensor.tensor import TensorVariable +from pytensor.tensor import TensorLike, TensorVariable from pymc.logprob.transforms import ( ChainedTransform, @@ -33,10 +33,15 @@ __all__ = [ "Chain", + "Chain", + "CholeskyCorrTransform", + "CholeskyCovPacked", "CholeskyCovPacked", "Interval", "Transform", "ZeroSumTransform", + "ZeroSumTransform", + "circular", "circular", "log", "log_exp_m1", @@ -135,6 +140,289 @@ def log_jac_det(self, value, *inputs): return pt.sum(y, axis=-1) +class CholeskyCorrTransform(Transform): + """ + Map an unconstrained real vector the Cholesky factor of a correlation matrix. + + For detailed description of the transform, [1]_ and [2]_. + + This is typically used with :class:`~pymc.distributions.LKJCholeskyCov` to place priors on correlation structures. + For a related transform that additionally rescales diagonal elements (working on covariance factors), see + :class:`~pymc.distributions.transforms.CholeskyCovPacked`. + + Adapted from the implementation in TensorFlow Probability [3]_: + https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + + Examples + -------- + + .. code-block:: python + + import numpy as np + import pytensor.tensor as pt + from pymc.distributions.transforms import CholeskyCorr + + unconstrained_vector = pt.as_tensor(np.array([2.0, 2.0, 1.0])) + n = unconstrained_vector.shape[0] + tr = CholeskyCorr(n) + constrained_matrix = tr.forward(unconstrained_vector) + y.eval() + array( + [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]] + ) + + References + ---------- + .. [1] Lewandowski, D., Kurowicka, D., & Joe, H. (2009). + Generating random correlation matrices based on vines and extended onion method. + Journal of Multivariate Analysis, 100(9), 1989–2001. + .. [2] Stan Development Team. Stan Functions Reference. Section on LKJ / Cholesky correlation. + .. [3] TensorFlow Probability. Correlation Cholesky bijector implementation. + https://github.com/tensorflow/probability/ + """ + + name = "cholesky_corr" + + def __init__(self, n, upper: bool = False): + """ + Initialize the CholeskyCorr transform. + + Parameters + ---------- + n : int + Size of the correlation matrix. + upper: bool, default False + If True, transform to an upper triangular matrix. If False, transform to a lower triangular matrix. 
+ """ + self.n = n + self.m = (n * (n + 1)) // 2 # Number of triangular elements + self.upper = upper + + super().__init__() + + def _fill_triangular_spiral( + self, x_raveled: TensorLike, unit_diag: bool = True + ) -> TensorVariable: + """ + Create a triangular matrix from a vector by filling it in a spiral order. + + This code is adapted from the `fill_triangular` function in TensorFlow Probability: + https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925 + + Parameters + ---------- + x_raveled: TensorLike + The input vector to be reshaped into a triangular matrix. + unit_diag: bool, default False + If True, the diagonal elements are assumed to be 1 and are not filled from the input vector. The input + vector is expected to have length m = n * (n - 1) / 2 in this case, containing only the off-diagonal + elements. + + Returns + ------- + triangular_matrix: TensorVariable + The resulting triangular matrix. + + Notes + ----- + By "spiral order", it is meant that the matrix is filled by jumping between the top and bottom rows, flipping + the fill order from left-to-right to right-to-left on each jump. For example, to fill a 4x4 matrix with + `order=True`, the matrix is filled in the following order: + + - Row 0, left to right + - Row 3, right to left + - Row 1, left to right + - Row 2, right to left + + When `upper` if False, everything is reversed: + + - Row 3, right to left + - Row 0, left to right + - Row 2, right to left + - Row 1, left to right + + After filling, entries not part of the triangular matrix are set to zero. + + Examples + -------- + + .. code-block:: python + + import numpy as np + from pymc.distributions.transforms import CholeskyCorr + + tr = CholeskyCorr(n=4) + x_unconstrained = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + tr._fill_triangular_spiral(x_unconstrained, upper=False).eval() + + # Out: + # array([[ 5, 0, 0, 0], + # [ 9, 10, 0, 0], + # [ 8, 7, 6, 0], + # [ 4, 3, 2, 1]]) + """ + x_raveled = pt.as_tensor(x_raveled) + *batch_shape, _ = x_raveled.shape + n, m = self.n, self.m + upper = self.upper + + if unit_diag: + n = n - 1 + + tail = x_raveled[..., n:] + + if upper: + xc = pt.concatenate([x_raveled, pt.flip(tail, -1)], axis=-1) + else: + xc = pt.concatenate([tail, pt.flip(x_raveled, -1)], axis=-1) + + y = pt.reshape(xc, (*batch_shape, n, n)) + return pt.triu(y) if upper else pt.tril(y) + + def _inverse_fill_triangular_spiral( + self, x: TensorLike, unit_diag: bool = True + ) -> TensorVariable: + """ + Inverse operation of `_fill_triangular_spiral`. + + Extracts the elements of a triangular matrix in spiral order and returns them as a vector. For details about + what is meant by "spiral order", see the docstring of `_fill_triangular_spiral`. + + Parameters + ---------- + x: TensorVariable + The input triangular matrix. + unit_diag: bool + If True, the diagonal elements are assumed to be 1 and are not included in the output vector. + + Returns + ------- + x_raveled: TensorVariable + The resulting vector containing the elements of the triangular matrix in spiral order. 
+ """ + x = pt.as_tensor(x) + *batch_shape, _, _ = x.shape + n, m = self.n, self.m + + if unit_diag: + m = m - n + n = n - 1 + + upper = self.upper + + if upper: + initial_elements = x[..., 0, :] + triangular_portion = x[..., 1:, :] + else: + initial_elements = pt.flip(x[..., -1, :], axis=-1) + triangular_portion = x[..., :-1, :] + + rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2)) + consolidated_matrix = triangular_portion + rotated_triangular_portion + end_sequence = pt.reshape( + consolidated_matrix, + (*batch_shape, pt.cast(n * (n - 1), "int64")), + ) + y = pt.concatenate([initial_elements, end_sequence[..., : m - n]], axis=-1) + + return y + + def forward(self, chol_corr_matrix: TensorLike, *inputs): + """ + Transform the Cholesky factor of a correlation matrix into a real-valued vector. + + Parameters + ---------- + chol_corr_matrix : TensorVariable + Cholesky factor of a correlation matrix R = L @ L.T of shape (n,n). + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. + + Returns + ------- + unconstrained_vector: TensorVariable + Real-valued vector of length m = n * (n - 1) / 2. + """ + chol_corr_matrix = pt.as_tensor(chol_corr_matrix) + n = self.n + + # Extract the reciprocal of the row norms from the diagonal. + diag = pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)[..., None] + + # Set the diagonal to 0s. + diag_idx = pt.arange(n) + chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(0) + + # Multiply with the norm (or divide by its reciprocal) to recover the + # unconstrained reals in the (strictly) lower triangular part. + unconstrained_matrix = chol_corr_matrix / diag + + # Remove the first row and last column before inverting the fill_triangular_spiral + # transformation. + return self._inverse_fill_triangular_spiral( + unconstrained_matrix[..., 1:, :-1], unit_diag=True + ) + + def backward(self, unconstrained_vector: TensorLike, *inputs): + """ + Transform a real-valued vector of length m = n * (n - 1) / 2 into the Cholesky factor of a correlation matrix. + + Parameters + ---------- + unconstrained_vector : TensorLike + Real-valued vector of length m = n * (n - 1) / 2. + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. + + Returns + ------- + unconstrained_vector: TensorVariable + Unconstrained real numbers. + """ + unconstrained_vector = pt.as_tensor(unconstrained_vector) + chol_corr_matrix = self._fill_triangular_spiral(unconstrained_vector, unit_diag=True) + + # Pad zeros on the top row and right column. + ndim = chol_corr_matrix.ndim + paddings = [*([(0, 0)] * (ndim - 2)), [1, 0], [0, 1]] + chol_corr_matrix = pt.pad(chol_corr_matrix, paddings) + + diag_idx = pt.arange(self.n) + chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(1) + + # Normalize each row to have Euclidean (L2) norm 1. + chol_corr_matrix /= pt.linalg.norm(chol_corr_matrix, axis=-1, ord=2)[..., None] + + return chol_corr_matrix + + def log_jac_det(self, unconstrained_vector: TensorLike, *inputs) -> TensorVariable: + """ + Compute the log determinant of the Jacobian. + + Parameters + ---------- + unconstrained_vector : TensorLike + Real-valued vector of length m = n * (n - 1) / 2. + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. + + Returns + ------- + log_jac_det: TensorVariable + Log determinant of the Jacobian of the transformation. 
+ """ + chol_corr_matrix = self.backward(unconstrained_vector, *inputs) + n = self.n + input_dtype = unconstrained_vector.dtype + + # TODO: tfp has a negative sign here; verify if it is needed + return pt.sum( + pt.arange(2, 2 + n, dtype=input_dtype) + * pt.log(pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)), + axis=-1, + ) + + class CholeskyCovPacked(Transform): """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 2597605bd1..5c1a3b112f 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -35,7 +35,6 @@ from pymc import Model from pymc.distributions.multivariate import ( - MultivariateIntervalTransform, _LKJCholeskyCov, _LKJCorr, _OrderedMultinomial, @@ -43,6 +42,7 @@ quaddist_matrix, ) from pymc.distributions.shape_utils import change_dist_size, to_tuple +from pymc.distributions.transforms import CholeskyCorrTransform from pymc.logprob.basic import logp from pymc.logprob.utils import ParameterValueError from pymc.math import kronecker @@ -559,10 +559,14 @@ def test_wishart(self, n): lambda value, nu, V: st.wishart.logpdf(value, int(nu), V), ) - @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES) - def test_lkjcorr(self, x, eta, n, lp): + @pytest.mark.parametrize("x_tri,eta,n,lp", LKJ_CASES) + def test_lkjcorr(self, x_tri, eta, n, lp): with pm.Model() as model: - pm.LKJCorr("lkj", eta=eta, n=n, default_transform=None, return_matrix=False) + pm.LKJCorr("lkj", eta=eta, n=n, transform=None) + + x = np.eye(n) + x[np.tril_indices(n, -1)] = x_tri + x[np.triu_indices(n, 1)] = x_tri point = {"lkj": x} decimals = select_by_precision(float64=6, float32=4) @@ -1308,17 +1312,17 @@ def test_kronecker_normal_support_point(self, mu, covs, size, expected): @pytest.mark.parametrize( "n, eta, size, expected", [ - (3, 1, None, np.zeros(3)), - (5, 1, None, np.zeros(10)), - pytest.param(3, 1, 1, np.zeros((1, 3))), - pytest.param(5, 1, (2, 3), np.zeros((2, 3, 10))), + (3, 1, None, np.eye(3)), + (5, 1, None, np.eye(5)), + (3, 1, (1,), np.broadcast_to(np.eye(3), (1, 3, 3))), + (5, 1, (2, 3), np.broadcast_to(np.eye(5), (2, 3, 5, 5))), ], + ids=["n=3", "n=5", "batch_1", "batch_2"], ) def test_lkjcorr_support_point(self, n, eta, size, expected): with pm.Model() as model: - pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False) - # LKJCorr logp is only implemented for vector values (size=None) - assert_support_point_is_expected(model, expected, check_finite_logp=size is None) + pm.LKJCorr("x", n=n, eta=eta, size=size) + assert_support_point_is_expected(model, expected, check_finite_logp=True) @pytest.mark.parametrize( "n, eta, size, expected", @@ -1462,17 +1466,20 @@ def test_with_lkjcorr_matrix( self, ): with pm.Model() as model: - corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True) - pm.Deterministic("corr_mat", corr) - mv = pm.MvNormal("mv", 0.0, cov=corr, size=4) + corr_mat = pm.LKJCorr("corr_mat", n=3, eta=2) + mv = pm.MvNormal("mv", 0.0, cov=corr_mat, size=4) prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False) assert prior["corr_mat"].shape == (10, 3, 3) # square - assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all() # 1.0 on diagonal assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all() # symmetric - assert ( - prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0 - ) # constrained between -1 and 1 + + np.testing.assert_allclose( + 
prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0 + ) # 1.0 on diagonal + + # constrained between -1 and 1 + assert prior["corr_mat"].max() <= (1.0 + 1e-12) + assert prior["corr_mat"].min() >= (-1.0 - 1e-12) def test_issue_3758(self): np.random.seed(42) @@ -2153,13 +2160,13 @@ class TestLKJCorr(BaseTestDistributionRandom): sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] sizes_expected = [ - (3,), - (3,), - (1, 3), - (1, 3), - (5, 3), - (4, 5, 3), - (2, 4, 2, 3), + (3, 3), + (3, 3), + (1, 3, 3), + (1, 3, 3), + (5, 3, 3), + (4, 5, 3, 3), + (2, 4, 2, 3, 3), ] checks_to_run = [ @@ -2172,7 +2179,8 @@ def check_draws_match_expected(self): def ref_rand(size, n, eta): shape = int(n * (n - 1) // 2) beta = eta - 1 + n / 2 - return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 + tril_values = (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 + return tril_values # If passed as a domain, continuous_random_tester would make `n` a shared variable # But this RV needs it to be constant in order to define the inner graph @@ -2188,24 +2196,12 @@ def ref_rand(size, n, eta): ) -@pytest.mark.parametrize( - argnames="shape", - argvalues=[ - (2,), - pytest.param( - (3, 2), - marks=pytest.mark.xfail( - raises=NotImplementedError, - reason="LKJCorr logp is only implemented for vector values (ndim=1)", - ), - ), - ], -) +@pytest.mark.parametrize("shape", [(2, 2), (3, 2, 2)], ids=["no_batch", "with_batch"]) def test_LKJCorr_default_transform(shape): with pm.Model() as m: - x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False) - assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform) - assert m.logp(sum=False)[0].type.shape == shape[:-1] + x = pm.LKJCorr("x", n=2, eta=1, shape=shape) + assert isinstance(m.rvs_to_transforms[x], CholeskyCorrTransform) + assert m.logp(sum=False)[0].type.shape == shape[:-2] class TestLKJCholeskyCov(BaseTestDistributionRandom): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 26bc8b1bf8..d01971f6d2 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -665,3 +665,90 @@ def log_jac_det(self, value, *inputs): match="are not allowed to broadcast together. 
         "There is a bug in the implementation of either one",
     ):
         m.logp(jacobian=jacobian_val)
+
+
+class TestLKJCholeskyCorrTransform:
+    def _get_test_values(self):
+        x_unconstrained = np.array([2.0, 2.0, 1.0])
+        x_constrained = np.array(
+            [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]]
+        )
+        return x_unconstrained, x_constrained
+
+    @pytest.mark.parametrize("upper", [True, False], ids=["upper", "lower"])
+    def test_fill_triangular_spiral(self, upper):
+        x_unconstrained = np.array([1, 2, 3, 4, 5, 6])
+
+        if upper:
+            x_constrained = np.array(
+                [
+                    [1, 2, 3],
+                    [0, 5, 6],
+                    [0, 0, 4],
+                ]
+            )
+        else:
+            x_constrained = np.array(
+                [
+                    [4, 0, 0],
+                    [6, 5, 0],
+                    [3, 2, 1],
+                ]
+            )
+
+        transform = tr.CholeskyCorrTransform(n=3, upper=upper)
+
+        np.testing.assert_allclose(
+            transform._fill_triangular_spiral(x_unconstrained, unit_diag=False).eval(),
+            x_constrained,
+        )
+
+        np.testing.assert_allclose(
+            transform._inverse_fill_triangular_spiral(x_constrained, unit_diag=False).eval(),
+            x_unconstrained,
+        )
+
+    def test_forward(self):
+        transform = tr.CholeskyCorrTransform(n=3, upper=False)
+        x_unconstrained, x_constrained = self._get_test_values()
+
+        np.testing.assert_allclose(
+            transform.forward(x_constrained).eval(),
+            x_unconstrained,
+            atol=1e-6,
+        )
+
+    def test_backward(self):
+        transform = tr.CholeskyCorrTransform(n=3, upper=False)
+        x_unconstrained, x_constrained = self._get_test_values()
+
+        np.testing.assert_allclose(
+            transform.backward(x_unconstrained).eval(),
+            x_constrained,
+            atol=1e-6,
+        )
+
+    def test_transform_round_trip(self):
+        transform = tr.CholeskyCorrTransform(n=3, upper=False)
+        x_unconstrained, x_constrained = self._get_test_values()
+
+        constrained_reconstructed = transform.backward(transform.forward(x_constrained)).eval()
+        unconstrained_reconstructed = transform.forward(transform.backward(x_unconstrained)).eval()
+
+        np.testing.assert_allclose(x_unconstrained, unconstrained_reconstructed, atol=1e-6)
+        np.testing.assert_allclose(x_constrained, constrained_reconstructed, atol=1e-6)
+
+    def test_log_jac_det(self):
+        transform = tr.CholeskyCorrTransform(n=3, upper=False)
+        x_unconstrained, x_constrained = self._get_test_values()
+
+        computed_log_jac_det = transform.log_jac_det(x_unconstrained).eval()
+
+        x = pt.tensor("x", shape=(3,))
+        lower_tri_vec = transform.backward(x)[pt.tril_indices(transform.n, k=-1)].ravel()
+        jac = pt.jacobian(lower_tri_vec, x, vectorize=True)
+        _, autodiff_log_jac_det = pt.linalg.slogdet(jac)
+
+        np.testing.assert_allclose(
+            autodiff_log_jac_det.eval({x: x_unconstrained}), computed_log_jac_det, atol=1e-6
+        )
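
A minimal usage sketch of the matrix-valued `LKJCorr` API introduced by this diff (the
model, data, and parameter values below are illustrative placeholders, not part of the
change itself):

    import numpy as np
    import pymc as pm

    with pm.Model():
        # LKJCorr now returns the full (n, n) correlation matrix directly;
        # the `return_matrix` argument is deprecated and ignored.
        corr = pm.LKJCorr("corr", n=3, eta=2.0)

        # Scale to a covariance matrix with fixed standard deviations.
        sds = np.array([1.0, 2.0, 3.0])
        cov = sds[:, None] * corr * sds[None, :]

        pm.MvNormal("y", mu=np.zeros(3), cov=cov, observed=np.zeros((10, 3)))

        # Sampling occurs on the unconstrained space of CholeskyCorrTransform,
        # i.e. n * (n - 1) / 2 = 3 free parameters per correlation matrix.
        idata = pm.sample()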