From 5cc74d180b50f8cc8bbeb45c3dcd7154ebb71fc1 Mon Sep 17 00:00:00 2001 From: John Cant Date: Fri, 21 Jun 2024 14:35:12 +0100 Subject: [PATCH 01/10] Port TF bijector to ensure posdef LKJCorr samples --- pymc/distributions/multivariate.py | 4 +- pymc/distributions/transforms.py | 115 +++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index f76a98546e..e6607e0624 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1647,7 +1647,9 @@ def logp(value, n, eta): @_default_transform.register(_LKJCorr) def lkjcorr_default_transform(op, rv): - return MultivariateIntervalTransform(-1.0, 1.0) + _, _, _, n, *_ = rv.owner.inputs + n = n.eval() + return transforms.CholeskyCorr(n) class LKJCorr: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index b1e02fd1f9..730509707d 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -16,6 +16,7 @@ import numpy as np import pytensor.tensor as pt +import pytensor from pytensor.graph import Op from pytensor.npy_2_compat import normalize_axis_tuple @@ -44,6 +45,11 @@ "ordered", "simplex", "sum_to_1", + "circular", + "CholeskyCorr", + "CholeskyCovPacked", + "Chain", + "ZeroSumTransform", ] @@ -135,6 +141,115 @@ def log_jac_det(self, value, *inputs): return pt.sum(y, axis=-1) +class CholeskyCorr(Transform): + """ + Transforms the off-diagonal elements of a correlation matrix to + unconstrained real numbers. + + Note: This is not particular to the LKJ distribution - it is only a + transform to help generate cholesky decompositions for random valid + correlation matrices. + + Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + + The backward side of this transformation is the off-diagonal upper + triangular elements of a correlation matrix, specified in row major order. + """ + + name = "cholesky-corr" + + def __init__(self, n): + """ + + Parameters + ---------- + n: int + Size of correlation matrix + """ + self.n = n + self.m = int(n*(n-1)/2) # number of off-diagonal elements + self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() + self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() + + def _generate_tril_indices(self): + row_indices, col_indices = np.tril_indices(self.n, -1) + return ( + pytensor.shared(row_indices), + pytensor.shared(col_indices) + ) + + def _generate_triu_indices(self): + row_indices, col_indices = np.triu_indices(self.n, 1) + return ( + pytensor.shared(row_indices), + pytensor.shared(col_indices) + ) + + def _jacobian(self, value, *inputs): + return pt.jacobian( + self.backward(value), + wrt=value + ) + + def log_jac_det(self, value, *inputs): + """ + Compute log of the determinant of the jacobian. + + There are no clever tricks here - we literally compute the jacobian + then compute its determinant then take log. + """ + jac = self._jacobian(value) + return pt.log(pt.linalg.det(jac)) + + def forward(self, value, *inputs): + """ + Convert the off-diagonal elements of a cholesky decomposition of a + correlation matrix to unconstrained real numbers. 
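+        (The input is the strictly upper triangle of the correlation matrix
+        in row-major order; the result is the strictly lower triangle of the
+        row-rescaled Cholesky factor.)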
+ """ + # The correlation matrix is specified via its upper triangular elements + corr = pt.set_subtensor( + pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs], + value + ) + corr = corr + corr.T + pt.eye(self.n) + + chol = pt.linalg.cholesky(corr) + + # Are the diagonals always guaranteed to be positive? + # I don't know, so we'll use abs + row_norms = 1/pt.abs(pt.diag(chol)) + + # Multiply by the row norms to undo the normalization + unconstrained = chol*row_norms[:, pt.newaxis] + + return unconstrained[self.tril_r_idxs, self.tril_c_idxs] + + def backward(self, value, *inputs, foo=False): + """ + Convert unconstrained real numbers to the off-diagonal elements of the + cholesky decomposition of a correlation matrix. + """ + # The diagonals of this matrix are 1, but these ones are just used for + # computing a denominator. The diagonals of the cholesky factor are not + # returned, but they are not ones. + chol_pre_norm = pt.set_subtensor( + pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs], + value + ) + + # derivative of pt.linalg.norm ended up complex, which caused errors +# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX") + + row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5) + chol = chol_pre_norm / row_norm[:, pt.newaxis] + + # Undo the cholesky decomposition + corr = pt.matmul(chol, chol.T) + + # We want the upper triangular indices here. + return corr[self.triu_r_idxs, self.triu_c_idxs] + + class CholeskyCovPacked(Transform): """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" From 1ade4a4be95d556f64db6f953d5f9c7bbf3a8313 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sat, 14 Sep 2024 17:07:12 +0100 Subject: [PATCH 02/10] Use GPT o1 to finish PR. --- pymc/distributions/multivariate.py | 4 +- pymc/distributions/transforms.py | 150 ++++++++++++++------------ tests/distributions/test_transform.py | 136 +++++++++++++++++++++++ 3 files changed, 221 insertions(+), 69 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index e6607e0624..1d844e80f6 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -1648,8 +1648,8 @@ def logp(value, n, eta): @_default_transform.register(_LKJCorr) def lkjcorr_default_transform(op, rv): _, _, _, n, *_ = rv.owner.inputs - n = n.eval() - return transforms.CholeskyCorr(n) + n = pt.get_scalar_constant_value(n) # Safely extract scalar value without eval + return CholeskyCorr(n) class LKJCorr: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 730509707d..84f8957c52 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -143,111 +143,127 @@ def log_jac_det(self, value, *inputs): class CholeskyCorr(Transform): """ - Transforms the off-diagonal elements of a correlation matrix to - unconstrained real numbers. + Transforms unconstrained real numbers to the off-diagonal elements of + a Cholesky decomposition of a correlation matrix. - Note: This is not particular to the LKJ distribution - it is only a - transform to help generate cholesky decompositions for random valid - correlation matrices. + This ensures that the resulting correlation matrix is positive definite. 
- Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + #### Mathematical Details - The backward side of this transformation is the off-diagonal upper - triangular elements of a correlation matrix, specified in row major order. + [Include detailed mathematical explanations similar to the original TFP bijector.] + + #### Examples + + ```python + transform = CholeskyCorr(n=3) + x = pt.as_tensor_variable([0.0, 0.0, 0.0]) + y = transform.forward(x).eval() + # y will be the off-diagonal elements of the Cholesky factor + + x_reconstructed = transform.backward(y).eval() + # x_reconstructed should closely match the original x + ``` + + #### References + - [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html) + - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001. """ name = "cholesky-corr" - def __init__(self, n): + def __init__(self, n, validate_args=False): """ + Initialize the CholeskyCorr transform. Parameters ---------- - n: int - Size of correlation matrix + n : int + Size of the correlation matrix. + validate_args : bool, default False + Whether to validate input arguments. """ self.n = n - self.m = int(n*(n-1)/2) # number of off-diagonal elements + self.m = int(n * (n - 1) / 2) # Number of off-diagonal elements self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() + super().__init__(validate_args=validate_args) def _generate_tril_indices(self): row_indices, col_indices = np.tril_indices(self.n, -1) - return ( - pytensor.shared(row_indices), - pytensor.shared(col_indices) - ) + return (row_indices, col_indices) def _generate_triu_indices(self): row_indices, col_indices = np.triu_indices(self.n, 1) - return ( - pytensor.shared(row_indices), - pytensor.shared(col_indices) - ) - - def _jacobian(self, value, *inputs): - return pt.jacobian( - self.backward(value), - wrt=value - ) + return (row_indices, col_indices) - def log_jac_det(self, value, *inputs): + def forward(self, x, *inputs): """ - Compute log of the determinant of the jacobian. + Forward transform: Unconstrained real numbers to Cholesky factors. - There are no clever tricks here - we literally compute the jacobian - then compute its determinant then take log. - """ - jac = self._jacobian(value) - return pt.log(pt.linalg.det(jac)) + Parameters + ---------- + x : tensor + Unconstrained real numbers. - def forward(self, value, *inputs): + Returns + ------- + tensor + Transformed Cholesky factors. """ - Convert the off-diagonal elements of a cholesky decomposition of a - correlation matrix to unconstrained real numbers. 
- """ - # The correlation matrix is specified via its upper triangular elements - corr = pt.set_subtensor( - pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs], - value + # Initialize a zero matrix + chol = pt.zeros((self.n, self.n), dtype=x.dtype) + + # Assign the unconstrained values to the lower triangular part + chol = pt.set_subtensor( + chol[self.tril_r_idxs, self.tril_c_idxs], + x ) - corr = corr + corr.T + pt.eye(self.n) - chol = pt.linalg.cholesky(corr) + # Normalize each row to have unit L2 norm + row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True)) + chol = chol / row_norms - # Are the diagonals always guaranteed to be positive? - # I don't know, so we'll use abs - row_norms = 1/pt.abs(pt.diag(chol)) + return chol[self.tril_r_idxs, self.tril_c_idxs] - # Multiply by the row norms to undo the normalization - unconstrained = chol*row_norms[:, pt.newaxis] + def backward(self, y, *inputs): + """ + Backward transform: Cholesky factors to unconstrained real numbers. - return unconstrained[self.tril_r_idxs, self.tril_c_idxs] + Parameters + ---------- + y : tensor + Cholesky factors. - def backward(self, value, *inputs, foo=False): - """ - Convert unconstrained real numbers to the off-diagonal elements of the - cholesky decomposition of a correlation matrix. + Returns + ------- + tensor + Unconstrained real numbers. """ - # The diagonals of this matrix are 1, but these ones are just used for - # computing a denominator. The diagonals of the cholesky factor are not - # returned, but they are not ones. - chol_pre_norm = pt.set_subtensor( - pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs], - value + # Reconstruct the full Cholesky matrix + chol = pt.zeros((self.n, self.n), dtype=y.dtype) + chol = pt.set_subtensor( + chol[self.triu_r_idxs, self.triu_c_idxs], + y ) + chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) + + # Perform Cholesky decomposition + chol = pt.linalg.cholesky(chol) - # derivative of pt.linalg.norm ended up complex, which caused errors -# row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX") + # Extract the unconstrained parameters by normalizing + row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1)) + unconstrained = chol / row_norms[:, None] - row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5) - chol = chol_pre_norm / row_norm[:, pt.newaxis] + return unconstrained[self.tril_r_idxs, self.tril_c_idxs] - # Undo the cholesky decomposition - corr = pt.matmul(chol, chol.T) + def log_jac_det(self, y, *inputs): + """ + Compute the log determinant of the Jacobian. - # We want the upper triangular indices here. - return corr[self.triu_r_idxs, self.triu_c_idxs] + The Jacobian determinant for normalization is the product of row norms. + """ + row_norms = pt.sqrt(pt.sum(y ** 2, axis=1)) + return -pt.sum(pt.log(row_norms), axis=-1) class CholeskyCovPacked(Transform): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 26bc8b1bf8..ee68db9b5c 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -23,6 +23,7 @@ import pymc as pm import pymc.distributions.transforms as tr +from pymc.distributions.transforms import CholeskyCorr from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform @@ -665,3 +666,138 @@ def log_jac_det(self, value, *inputs): match="are not allowed to broadcast together. 
There is a bug in the implementation of either one", ): m.logp(jacobian=jacobian_val) + + +def test_lkjcorr_transform_round_trip(): + """ + Test that applying the forward transform followed by the backward transform + retrieves the original unconstrained parameters, and that sampled matrices are positive definite. + """ + with pm.Model() as model: + rho = pm.LKJCorr("rho", n=3, eta=2) + + trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + + # Extract the sampled correlation matrices + rho_samples = trace["rho"] + num_samples = rho_samples.shape[0] + + for i in range(num_samples): + sample_matrix = rho_samples[i] + + # Check if the sampled matrix is positive definite + try: + np.linalg.cholesky(sample_matrix) + except np.linalg.LinAlgError: + pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") + + # Perform round-trip transform: forward and then backward + transform = CholeskyCorr(n=3) + unconstrained = transform.forward(pt.as_tensor_variable(sample_matrix)).eval() + reconstructed = transform.backward(unconstrained).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(sample_matrix, reconstructed, atol=1e-6) + + +def test_lkjcorr_log_jac_det(): + """ + Verify that the computed log determinant of the Jacobian matches the expected closed-form solution. + """ + n = 3 + transform = CholeskyCorr(n=n) + + # Create a sample unconstrained vector (all zeros for simplicity) + x = np.zeros(int(n * (n - 1) / 2), dtype=pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform to obtain Cholesky factors + y = transform.forward(x_tensor).eval() + + # Compute the log determinant using the transform's method + computed_log_jac_det = transform.log_jac_det(y).eval() + + # Expected log determinant: 0 (since row norms are 1) + expected_log_jac_det = 0.0 + + assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6) + + +@pytest.mark.parametrize("n", [2, 4, 5]) +def test_lkjcorr_transform_various_sizes(n): + """ + Test the CholeskyCorr transform with various sizes of correlation matrices. + """ + transform = CholeskyCorr(n=n) + unconstrained_size = int(n * (n - 1) / 2) + + # Generate random unconstrained real numbers + x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform + y = transform.forward(x_tensor).eval() + + # Perform backward transform + reconstructed = transform.backward(y).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(x, reconstructed, atol=1e-6) + + +def test_lkjcorr_invalid_n(): + """ + Test that initializing CholeskyCorr with invalid 'n' values raises appropriate errors. + """ + with pytest.raises(ValueError): + # 'n' must be an integer greater than 1 + CholeskyCorr(n=1) + + with pytest.raises(TypeError): + # 'n' must be an integer + CholeskyCorr(n='three') + + +def test_lkjcorr_positive_definite(): + """ + Ensure that all sampled correlation matrices are positive definite. 
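+    Positive definiteness is verified by attempting a np.linalg.cholesky
+    factorization of each sampled matrix.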
+ """ + with pm.Model() as model: + rho = pm.LKJCorr("rho", n=4, eta=2) + + trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + + # Extract the sampled correlation matrices + rho_samples = trace["rho"] + num_samples = rho_samples.shape[0] + + for i in range(num_samples): + sample_matrix = rho_samples[i] + + # Check if the sampled matrix is positive definite + try: + np.linalg.cholesky(sample_matrix) + except np.linalg.LinAlgError: + pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") + + +def test_lkjcorr_round_trip_various_sizes(): + """ + Perform round-trip transformation tests for various sizes of correlation matrices. + """ + for n in [2, 3, 4]: + transform = CholeskyCorr(n=n) + unconstrained_size = int(n * (n - 1) / 2) + + # Generate random unconstrained real numbers + x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) + x_tensor = pt.as_tensor_variable(x) + + # Perform forward transform + y = transform.forward(x_tensor).eval() + + # Perform backward transform + reconstructed = transform.backward(y).eval() + + # Assert that the original and reconstructed unconstrained parameters are close + assert_allclose(x, reconstructed, atol=1e-6) \ No newline at end of file From 55313abe1e6976ff948bc075aaf509428454d9c9 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sat, 14 Sep 2024 17:17:30 +0100 Subject: [PATCH 03/10] Linter fixes. --- pymc/distributions/multivariate.py | 7 ++++++- pymc/distributions/transforms.py | 17 +++++------------ tests/distributions/test_transform.py | 14 +++++++++----- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 1d844e80f6..80d7d8f415 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -73,7 +73,12 @@ rv_size_is_none, to_tuple, ) -from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform +from pymc.distributions.transforms import ( + CholeskyCorr, + Interval, + ZeroSumTransform, + _default_transform, +) from pymc.logprob.abstract import _logprob from pymc.logprob.rewriting import ( specialization_ir_rewrites_db, diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 84f8957c52..27ccfc9415 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -16,7 +16,6 @@ import numpy as np import pytensor.tensor as pt -import pytensor from pytensor.graph import Op from pytensor.npy_2_compat import normalize_axis_tuple @@ -214,13 +213,10 @@ def forward(self, x, *inputs): chol = pt.zeros((self.n, self.n), dtype=x.dtype) # Assign the unconstrained values to the lower triangular part - chol = pt.set_subtensor( - chol[self.tril_r_idxs, self.tril_c_idxs], - x - ) + chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x) # Normalize each row to have unit L2 norm - row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True)) + row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True)) chol = chol / row_norms return chol[self.tril_r_idxs, self.tril_c_idxs] @@ -241,17 +237,14 @@ def backward(self, y, *inputs): """ # Reconstruct the full Cholesky matrix chol = pt.zeros((self.n, self.n), dtype=y.dtype) - chol = pt.set_subtensor( - chol[self.triu_r_idxs, self.triu_c_idxs], - y - ) + chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y) chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) # Perform Cholesky decomposition chol 
= pt.linalg.cholesky(chol) # Extract the unconstrained parameters by normalizing - row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1)) + row_norms = pt.sqrt(pt.sum(chol**2, axis=1)) unconstrained = chol / row_norms[:, None] return unconstrained[self.tril_r_idxs, self.tril_c_idxs] @@ -262,7 +255,7 @@ def log_jac_det(self, y, *inputs): The Jacobian determinant for normalization is the product of row norms. """ - row_norms = pt.sqrt(pt.sum(y ** 2, axis=1)) + row_norms = pt.sqrt(pt.sum(y**2, axis=1)) return -pt.sum(pt.log(row_norms), axis=-1) diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index ee68db9b5c..55de3249cc 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -23,8 +23,8 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.distributions.transforms import CholeskyCorr +from pymc.distributions.transforms import CholeskyCorr from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform from pymc.pytensorf import floatX, jacobian @@ -676,7 +676,9 @@ def test_lkjcorr_transform_round_trip(): with pm.Model() as model: rho = pm.LKJCorr("rho", n=3, eta=2) - trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + trace = pm.sample( + 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False + ) # Extract the sampled correlation matrices rho_samples = trace["rho"] @@ -755,7 +757,7 @@ def test_lkjcorr_invalid_n(): with pytest.raises(TypeError): # 'n' must be an integer - CholeskyCorr(n='three') + CholeskyCorr(n="three") def test_lkjcorr_positive_definite(): @@ -765,7 +767,9 @@ def test_lkjcorr_positive_definite(): with pm.Model() as model: rho = pm.LKJCorr("rho", n=4, eta=2) - trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False) + trace = pm.sample( + 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False + ) # Extract the sampled correlation matrices rho_samples = trace["rho"] @@ -800,4 +804,4 @@ def test_lkjcorr_round_trip_various_sizes(): reconstructed = transform.backward(y).eval() # Assert that the original and reconstructed unconstrained parameters are close - assert_allclose(x, reconstructed, atol=1e-6) \ No newline at end of file + assert_allclose(x, reconstructed, atol=1e-6) From 8c4e1ad035665406ffa631259795204d396ee6f4 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Sun, 15 Sep 2024 08:48:18 +0100 Subject: [PATCH 04/10] Update doc string. Ask o1-mini to improve test. --- pymc/distributions/transforms.py | 52 ++++++++++++++++++++++++++- tests/distributions/test_transform.py | 27 ++++++++++---- 2 files changed, 71 insertions(+), 8 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 27ccfc9415..25064a0ba1 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -149,7 +149,57 @@ class CholeskyCorr(Transform): #### Mathematical Details - [Include detailed mathematical explanations similar to the original TFP bijector.] + This bijector provides a change of variables from unconstrained reals to a + parameterization of the CholeskyLKJ distribution. The CholeskyLKJ distribution + [1] is a distribution on the set of Cholesky factors of positive definite + correlation matrices. 
The CholeskyLKJ probability density function is + obtained from the LKJ density on n x n matrices as follows: + + 1 = int p(A | eta) dA + = int Z(eta) * det(A) ** (eta - 1) dA + = int Z(eta) L_ii ** {(n - i - 1) + 2 * (eta - 1)} ^dL_ij (0 <= i < j < n) + + where Z(eta) is the normalizer; the matrix L is the Cholesky factor of the + correlation matrix A; and ^dL_ij denotes the wedge product (or differential) + of the strictly lower triangular entries of L. The entries L_ij are + constrained such that each entry lies in [-1, 1] and the norm of each row is + 1. The norm includes the diagonal; which is not included in the wedge product. + To preserve uniqueness, we further specify that the diagonal entries are + positive. + + The image of unconstrained reals under the `CorrelationCholesky` bijector is + the set of correlation matrices which are positive definite. A [correlation + matrix](https://en.wikipedia.org/wiki/Correlation_and_dependence#Correlation_matrices) + can be characterized as a symmetric positive semidefinite matrix with 1s on + the main diagonal. + + For a lower triangular matrix `L` to be a valid Cholesky-factor of a positive + definite correlation matrix, it is necessary and sufficient that each row of + `L` have unit Euclidean norm [1]. To see this, observe that if `L_i` is the + `i`th row of the Cholesky factor corresponding to the correlation matrix `R`, + then the `i`th diagonal entry of `R` satisfies: + + 1 = R_i,i = L_i . L_i = ||L_i||^2 + + where '.' is the dot product of vectors and `||...||` denotes the Euclidean + norm. + + Furthermore, observe that `R_i,j` lies in the interval `[-1, 1]`. By the + Cauchy-Schwarz inequality: + + |R_i,j| = |L_i . L_j| <= ||L_i|| ||L_j|| = 1 + + This is a consequence of the fact that `R` is symmetric positive definite with + 1s on the main diagonal. + + We choose the mapping from x in `R^{m}` to `R^{n^2}` where `m` is the + `(n - 1)`th triangular number; i.e. `m = 1 + 2 + ... + (n - 1)`. + + L_ij = x_i,j / s_i (for i < j) + L_ii = 1 / s_i + + where s_i = sqrt(1 + x_i,0^2 + x_i,1^2 + ... + x_(i,i-1)^2). We can check that + the required constraints on the image are satisfied. #### Examples diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 55de3249cc..8a50e7ffe7 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -704,25 +704,38 @@ def test_lkjcorr_transform_round_trip(): def test_lkjcorr_log_jac_det(): """ - Verify that the computed log determinant of the Jacobian matches the expected closed-form solution. + Verify that the computed log determinant of the Jacobian matches the expected value + obtained from PyTensor's automatic differentiation with a non-trivial input. 
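+    The reference value is log |det J|, where J is the Jacobian of the
+    backward transform computed with pt.jacobian.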
""" n = 3 transform = CholeskyCorr(n=n) - # Create a sample unconstrained vector (all zeros for simplicity) - x = np.zeros(int(n * (n - 1) / 2), dtype=pytensor.config.floatX) + # Create a non-trivial sample unconstrained vector + x = np.random.randn(int(n * (n - 1) / 2)).astype(pytensor.config.floatX) x_tensor = pt.as_tensor_variable(x) # Perform forward transform to obtain Cholesky factors - y = transform.forward(x_tensor).eval() + y = transform.forward(x_tensor) # Compute the log determinant using the transform's method computed_log_jac_det = transform.log_jac_det(y).eval() - # Expected log determinant: 0 (since row norms are 1) - expected_log_jac_det = 0.0 + # Define the backward function + backward = transform.backward + + # Compute the Jacobian matrix using PyTensor's automatic differentiation + backward_transformed = backward(y) + jacobian_matrix = pt.jacobian(backward_transformed, y) + + # Compile the function to compute the Jacobian matrix + jacobian_func = pytensor.function([], jacobian_matrix) + jacobian_val = jacobian_func() + + # Compute the log determinant of the Jacobian matrix + actual_log_jac_det = np.log(np.abs(np.linalg.det(jacobian_val))) - assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6) + # Compare the two + assert_allclose(computed_log_jac_det, actual_log_jac_det, atol=1e-6) @pytest.mark.parametrize("n", [2, 4, 5]) From cf8d9a8e2c9cb3a4cdb94cb17e16ee24742fe4b3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 18:47:11 -0500 Subject: [PATCH 05/10] Finalize transformation --- pymc/distributions/transforms.py | 341 +++++++++++++++++--------- tests/distributions/test_transform.py | 212 ++++++---------- 2 files changed, 301 insertions(+), 252 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 25064a0ba1..465b3fc46e 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -19,7 +19,7 @@ from pytensor.graph import Op from pytensor.npy_2_compat import normalize_axis_tuple -from pytensor.tensor import TensorVariable +from pytensor.tensor import TensorLike, TensorVariable from pymc.logprob.transforms import ( ChainedTransform, @@ -33,10 +33,15 @@ __all__ = [ "Chain", + "Chain", + "CholeskyCorr", + "CholeskyCovPacked", "CholeskyCovPacked", "Interval", "Transform", "ZeroSumTransform", + "ZeroSumTransform", + "circular", "circular", "log", "log_exp_m1", @@ -44,11 +49,6 @@ "ordered", "simplex", "sum_to_1", - "circular", - "CholeskyCorr", - "CholeskyCovPacked", - "Chain", - "ZeroSumTransform", ] @@ -142,171 +142,286 @@ def log_jac_det(self, value, *inputs): class CholeskyCorr(Transform): """ - Transforms unconstrained real numbers to the off-diagonal elements of - a Cholesky decomposition of a correlation matrix. + Map an unconstrained real vector the Cholesky factor of a correlation matrix. - This ensures that the resulting correlation matrix is positive definite. + For detailed description of the transform, [1]_ and [2]_. - #### Mathematical Details + This is typically used with :class:`~pymc.distributions.LKJCholeskyCov` to place priors on correlation structures. + For a related transform that additionally rescales diagonal elements (working on covariance factors), see + :class:`~pymc.distributions.transforms.CholeskyCovPacked`. - This bijector provides a change of variables from unconstrained reals to a - parameterization of the CholeskyLKJ distribution. 
The CholeskyLKJ distribution - [1] is a distribution on the set of Cholesky factors of positive definite - correlation matrices. The CholeskyLKJ probability density function is - obtained from the LKJ density on n x n matrices as follows: + Adapted from the implementation in TensorFlow Probability [3]_: + https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + + Examples + -------- - 1 = int p(A | eta) dA - = int Z(eta) * det(A) ** (eta - 1) dA - = int Z(eta) L_ii ** {(n - i - 1) + 2 * (eta - 1)} ^dL_ij (0 <= i < j < n) + .. code-block:: python - where Z(eta) is the normalizer; the matrix L is the Cholesky factor of the - correlation matrix A; and ^dL_ij denotes the wedge product (or differential) - of the strictly lower triangular entries of L. The entries L_ij are - constrained such that each entry lies in [-1, 1] and the norm of each row is - 1. The norm includes the diagonal; which is not included in the wedge product. - To preserve uniqueness, we further specify that the diagonal entries are - positive. + import numpy as np + import pytensor.tensor as pt + from pymc.distributions.transforms import CholeskyCorr - The image of unconstrained reals under the `CorrelationCholesky` bijector is - the set of correlation matrices which are positive definite. A [correlation - matrix](https://en.wikipedia.org/wiki/Correlation_and_dependence#Correlation_matrices) - can be characterized as a symmetric positive semidefinite matrix with 1s on - the main diagonal. + unconstrained_vector = pt.as_tensor(np.array([2.0, 2.0, 1.0])) + n = unconstrained_vector.shape[0] + tr = CholeskyCorr(n) + constrained_matrix = tr.forward(unconstrained_vector) + y.eval() + array( + [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]] + ) - For a lower triangular matrix `L` to be a valid Cholesky-factor of a positive - definite correlation matrix, it is necessary and sufficient that each row of - `L` have unit Euclidean norm [1]. To see this, observe that if `L_i` is the - `i`th row of the Cholesky factor corresponding to the correlation matrix `R`, - then the `i`th diagonal entry of `R` satisfies: + References + ---------- + .. [1] Lewandowski, D., Kurowicka, D., & Joe, H. (2009). + Generating random correlation matrices based on vines and extended onion method. + Journal of Multivariate Analysis, 100(9), 1989–2001. + .. [2] Stan Development Team. Stan Functions Reference. Section on LKJ / Cholesky correlation. + .. [3] TensorFlow Probability. Correlation Cholesky bijector implementation. + https://github.com/tensorflow/probability/ + """ - 1 = R_i,i = L_i . L_i = ||L_i||^2 + name = "cholesky-corr" - where '.' is the dot product of vectors and `||...||` denotes the Euclidean - norm. + def __init__(self, n, upper: bool = False): + """ + Initialize the CholeskyCorr transform. - Furthermore, observe that `R_i,j` lies in the interval `[-1, 1]`. By the - Cauchy-Schwarz inequality: + Parameters + ---------- + n : int + Size of the correlation matrix. + upper: bool, default False + If True, transform to an upper triangular matrix. If False, transform to a lower triangular matrix. + """ + self.n = n + self.m = (n * (n + 1)) // 2 # Number of triangular elements + self.upper = upper - |R_i,j| = |L_i . L_j| <= ||L_i|| ||L_j|| = 1 + super().__init__() - This is a consequence of the fact that `R` is symmetric positive definite with - 1s on the main diagonal. 
+ def _fill_triangular_spiral( + self, x_raveled: TensorLike, unit_diag: bool = True + ) -> TensorVariable: + """ + Create a triangular matrix from a vector by filling it in a spiral order. - We choose the mapping from x in `R^{m}` to `R^{n^2}` where `m` is the - `(n - 1)`th triangular number; i.e. `m = 1 + 2 + ... + (n - 1)`. + This code is adapted from the `fill_triangular` function in TensorFlow Probability: + https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925 - L_ij = x_i,j / s_i (for i < j) - L_ii = 1 / s_i + Parameters + ---------- + x_raveled: TensorLike + The input vector to be reshaped into a triangular matrix. + unit_diag: bool, default False + If True, the diagonal elements are assumed to be 1 and are not filled from the input vector. The input + vector is expected to have length m = n * (n - 1) / 2 in this case, containing only the off-diagonal + elements. - where s_i = sqrt(1 + x_i,0^2 + x_i,1^2 + ... + x_(i,i-1)^2). We can check that - the required constraints on the image are satisfied. + Returns + ------- + triangular_matrix: TensorVariable + The resulting triangular matrix. - #### Examples + Notes + ----- + By "spiral order", it is meant that the matrix is filled by jumping between the top and bottom rows, flipping + the fill order from left-to-right to right-to-left on each jump. For example, to fill a 4x4 matrix with + `order=True`, the matrix is filled in the following order: - ```python - transform = CholeskyCorr(n=3) - x = pt.as_tensor_variable([0.0, 0.0, 0.0]) - y = transform.forward(x).eval() - # y will be the off-diagonal elements of the Cholesky factor + - Row 0, left to right + - Row 3, right to left + - Row 1, left to right + - Row 2, right to left - x_reconstructed = transform.backward(y).eval() - # x_reconstructed should closely match the original x - ``` + When `upper` if False, everything is reversed: - #### References - - [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html) - - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001. - """ + - Row 3, right to left + - Row 0, left to right + - Row 2, right to left + - Row 1, left to right - name = "cholesky-corr" + After filling, entries not part of the triangular matrix are set to zero. - def __init__(self, n, validate_args=False): + Examples + -------- + + .. code-block:: python + + import numpy as np + from pymc.distributions.transforms import CholeskyCorr + + tr = CholeskyCorr(n=4) + x_unconstrained = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + tr._fill_triangular_spiral(x_unconstrained, upper=False).eval() + + # Out: + # array([[ 5, 0, 0, 0], + # [ 9, 10, 0, 0], + # [ 8, 7, 6, 0], + # [ 4, 3, 2, 1]]) """ - Initialize the CholeskyCorr transform. 
+ x_raveled = pt.as_tensor(x_raveled) + *batch_shape, _ = x_raveled.shape + n, m = self.n, self.m + upper = self.upper + + if unit_diag: + m -= n + n -= 1 + + tail = x_raveled[..., n:] + + if upper: + xc = pt.concatenate([x_raveled, pt.flip(tail, -1)]) + else: + xc = pt.concatenate([tail, pt.flip(x_raveled, -1)]) + + y = pt.reshape(xc, (*batch_shape, n, n)) + return pt.triu(y) if upper else pt.tril(y) + + def _inverse_fill_triangular_spiral( + self, x: TensorLike, unit_diag: bool = True + ) -> TensorVariable: + """ + Inverse operation of `_fill_triangular_spiral`. + + Extracts the elements of a triangular matrix in spiral order and returns them as a vector. For details about + what is meant by "spiral order", see the docstring of `_fill_triangular_spiral`. Parameters ---------- - n : int - Size of the correlation matrix. - validate_args : bool, default False - Whether to validate input arguments. + x: TensorVariable + The input triangular matrix. + unit_diag: bool + If True, the diagonal elements are assumed to be 1 and are not included in the output vector. + + Returns + ------- + x_raveled: TensorVariable + The resulting vector containing the elements of the triangular matrix in spiral order. """ - self.n = n - self.m = int(n * (n - 1) / 2) # Number of off-diagonal elements - self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices() - self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices() - super().__init__(validate_args=validate_args) + x = pt.as_tensor(x) + *batch_shape, _, _ = x.shape + n, m = self.n, self.m + + if unit_diag: + m -= n + n -= 1 - def _generate_tril_indices(self): - row_indices, col_indices = np.tril_indices(self.n, -1) - return (row_indices, col_indices) + upper = self.upper - def _generate_triu_indices(self): - row_indices, col_indices = np.triu_indices(self.n, 1) - return (row_indices, col_indices) + if upper: + initial_elements = x[..., 0, :] + triangular_portion = x[..., 1:, :] + else: + initial_elements = pt.flip(x[..., -1, :], axis=-1) + triangular_portion = x[..., :-1, :] + + rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2)) + consolidated_matrix = triangular_portion + rotated_triangular_portion + end_sequence = pt.reshape( + consolidated_matrix, + (*batch_shape, pt.cast(n * (n - 1), "int64")), + ) + y = pt.concatenate([initial_elements, end_sequence[..., : m - n]], axis=-1) + + return y - def forward(self, x, *inputs): + def forward(self, chol_corr_matrix: TensorLike, *inputs): """ - Forward transform: Unconstrained real numbers to Cholesky factors. + Transform the Cholesky factor of a correlation matrix into a real-valued vector. Parameters ---------- - x : tensor - Unconstrained real numbers. + chol_corr_matrix : TensorVariable + Cholesky factor of a correlation matrix R = L @ L.T of shape (n,n). + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. Returns ------- - tensor - Transformed Cholesky factors. + unconstrained_vector: TensorVariable + Real-valued vector of length m = n * (n - 1) / 2. """ - # Initialize a zero matrix - chol = pt.zeros((self.n, self.n), dtype=x.dtype) + chol_corr_matrix = pt.as_tensor(chol_corr_matrix) + n = self.n + + # Extract the reciprocal of the row norms from the diagonal. + diag = pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)[..., None] - # Assign the unconstrained values to the lower triangular part - chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x) + # Set the diagonal to 0s. 
+ diag_idx = pt.arange(n) + chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(0) - # Normalize each row to have unit L2 norm - row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True)) - chol = chol / row_norms + # Multiply with the norm (or divide by its reciprocal) to recover the + # unconstrained reals in the (strictly) lower triangular part. + unconstrained_matrix = chol_corr_matrix / diag - return chol[self.tril_r_idxs, self.tril_c_idxs] + # Remove the first row and last column before inverting the fill_triangular_spiral + # transformation. + return self._inverse_fill_triangular_spiral( + unconstrained_matrix[..., 1:, :-1], unit_diag=True + ) - def backward(self, y, *inputs): + def backward(self, unconstrained_vector: TensorLike, *inputs): """ - Backward transform: Cholesky factors to unconstrained real numbers. + Transform a real-valued vector of length m = n * (n - 1) / 2 into the Cholesky factor of a correlation matrix. Parameters ---------- - y : tensor - Cholesky factors. + unconstrained_vector : TensorLike + Real-valued vector of length m = n * (n - 1) / 2. + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. Returns ------- - tensor + unconstrained_vector: TensorVariable Unconstrained real numbers. """ - # Reconstruct the full Cholesky matrix - chol = pt.zeros((self.n, self.n), dtype=y.dtype) - chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y) - chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype) + unconstrained_vector = pt.as_tensor(unconstrained_vector) + chol_corr_matrix = self._fill_triangular_spiral(unconstrained_vector, unit_diag=True) - # Perform Cholesky decomposition - chol = pt.linalg.cholesky(chol) + # Pad zeros on the top row and right column. + ndim = chol_corr_matrix.ndim + paddings = [*([(0, 0)] * (ndim - 2)), [1, 0], [0, 1]] + chol_corr_matrix = pt.pad(chol_corr_matrix, paddings) - # Extract the unconstrained parameters by normalizing - row_norms = pt.sqrt(pt.sum(chol**2, axis=1)) - unconstrained = chol / row_norms[:, None] + diag_idx = pt.arange(self.n) + chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(1) - return unconstrained[self.tril_r_idxs, self.tril_c_idxs] + # Normalize each row to have Euclidean (L2) norm 1. + chol_corr_matrix /= pt.linalg.norm(chol_corr_matrix, axis=-1, ord=2)[..., None] - def log_jac_det(self, y, *inputs): + return chol_corr_matrix + + def log_jac_det(self, unconstrained_vector: TensorLike, *inputs) -> TensorVariable: """ Compute the log determinant of the Jacobian. - The Jacobian determinant for normalization is the product of row norms. + Parameters + ---------- + unconstrained_vector : TensorLike + Real-valued vector of length m = n * (n - 1) / 2. + inputs: + Additional input values. Not used; included for signature compatibility with other transformations. + + Returns + ------- + log_jac_det: TensorVariable + Log determinant of the Jacobian of the transformation. 
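+
+        Notes
+        -----
+        The value follows the closed-form expression from the TFP
+        implementation: a weighted sum of the logs of the diagonal entries of
+        the reconstructed Cholesky factor, so no autodiff of the transform is
+        required.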
""" - row_norms = pt.sqrt(pt.sum(y**2, axis=1)) - return -pt.sum(pt.log(row_norms), axis=-1) + chol_corr_matrix = self.backward(unconstrained_vector, *inputs) + n = self.n + input_dtype = unconstrained_vector.dtype + + # TODO: tfp has a negative sign here; verify if it is needed + return pt.sum( + pt.arange(2, 2 + n, dtype=input_dtype) + * pt.log(pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)), + axis=-1, + ) class CholeskyCovPacked(Transform): diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 8a50e7ffe7..f640d5f854 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -24,7 +24,6 @@ import pymc as pm import pymc.distributions.transforms as tr -from pymc.distributions.transforms import CholeskyCorr from pymc.logprob.basic import transformed_conditional_logp from pymc.logprob.transforms import Transform from pymc.pytensorf import floatX, jacobian @@ -668,153 +667,88 @@ def log_jac_det(self, value, *inputs): m.logp(jacobian=jacobian_val) -def test_lkjcorr_transform_round_trip(): - """ - Test that applying the forward transform followed by the backward transform - retrieves the original unconstrained parameters, and that sampled matrices are positive definite. - """ - with pm.Model() as model: - rho = pm.LKJCorr("rho", n=3, eta=2) - - trace = pm.sample( - 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False - ) - - # Extract the sampled correlation matrices - rho_samples = trace["rho"] - num_samples = rho_samples.shape[0] - - for i in range(num_samples): - sample_matrix = rho_samples[i] - - # Check if the sampled matrix is positive definite - try: - np.linalg.cholesky(sample_matrix) - except np.linalg.LinAlgError: - pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") - - # Perform round-trip transform: forward and then backward - transform = CholeskyCorr(n=3) - unconstrained = transform.forward(pt.as_tensor_variable(sample_matrix)).eval() - reconstructed = transform.backward(unconstrained).eval() - - # Assert that the original and reconstructed unconstrained parameters are close - assert_allclose(sample_matrix, reconstructed, atol=1e-6) - - -def test_lkjcorr_log_jac_det(): - """ - Verify that the computed log determinant of the Jacobian matches the expected value - obtained from PyTensor's automatic differentiation with a non-trivial input. 
- """ - n = 3 - transform = CholeskyCorr(n=n) - - # Create a non-trivial sample unconstrained vector - x = np.random.randn(int(n * (n - 1) / 2)).astype(pytensor.config.floatX) - x_tensor = pt.as_tensor_variable(x) - - # Perform forward transform to obtain Cholesky factors - y = transform.forward(x_tensor) - - # Compute the log determinant using the transform's method - computed_log_jac_det = transform.log_jac_det(y).eval() - - # Define the backward function - backward = transform.backward - - # Compute the Jacobian matrix using PyTensor's automatic differentiation - backward_transformed = backward(y) - jacobian_matrix = pt.jacobian(backward_transformed, y) - - # Compile the function to compute the Jacobian matrix - jacobian_func = pytensor.function([], jacobian_matrix) - jacobian_val = jacobian_func() - - # Compute the log determinant of the Jacobian matrix - actual_log_jac_det = np.log(np.abs(np.linalg.det(jacobian_val))) - - # Compare the two - assert_allclose(computed_log_jac_det, actual_log_jac_det, atol=1e-6) - - -@pytest.mark.parametrize("n", [2, 4, 5]) -def test_lkjcorr_transform_various_sizes(n): - """ - Test the CholeskyCorr transform with various sizes of correlation matrices. - """ - transform = CholeskyCorr(n=n) - unconstrained_size = int(n * (n - 1) / 2) - - # Generate random unconstrained real numbers - x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) - x_tensor = pt.as_tensor_variable(x) - - # Perform forward transform - y = transform.forward(x_tensor).eval() - - # Perform backward transform - reconstructed = transform.backward(y).eval() - - # Assert that the original and reconstructed unconstrained parameters are close - assert_allclose(x, reconstructed, atol=1e-6) - - -def test_lkjcorr_invalid_n(): - """ - Test that initializing CholeskyCorr with invalid 'n' values raises appropriate errors. - """ - with pytest.raises(ValueError): - # 'n' must be an integer greater than 1 - CholeskyCorr(n=1) - - with pytest.raises(TypeError): - # 'n' must be an integer - CholeskyCorr(n="three") +class TestLJKCholeskyCorr: + def _get_test_values(self): + x_unconstrained = np.array([2.0, 2.0, 1.0]) + x_constrained = np.array( + [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]] + ) + return x_unconstrained, x_constrained + + @pytest.mark.parametrize("upper", [True, False], ids=["upper", "lower"]) + def test_fill_triangular_spiral(self, upper): + x_unconstrained = np.array([1, 2, 3, 4, 5, 6]) + + if upper: + x_constrained = np.array( + [ + [1, 2, 3], + [0, 5, 6], + [0, 0, 4], + ] + ) + else: + x_constrained = np.array( + [ + [4, 0, 0], + [6, 5, 0], + [3, 2, 1], + ] + ) + + transform = tr.CholeskyCorr(n=3, upper=upper) + + np.testing.assert_allclose( + transform._fill_triangular_spiral(x_unconstrained, unit_diag=False).eval(), + x_constrained, + ) + np.testing.assert_allclose( + transform._inverse_fill_triangular_spiral(x_constrained, unit_diag=False).eval(), + x_unconstrained, + ) -def test_lkjcorr_positive_definite(): - """ - Ensure that all sampled correlation matrices are positive definite. 
- """ - with pm.Model() as model: - rho = pm.LKJCorr("rho", n=4, eta=2) + def test_forward(self): + transform = tr.CholeskyCorr(n=3, upper=False) + x_unconstrained, x_constrained = self._get_test_values() - trace = pm.sample( - 100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False - ) + np.testing.assert_allclose( + transform.forward(x_constrained).eval(), + x_unconstrained, + atol=1e-6, + ) - # Extract the sampled correlation matrices - rho_samples = trace["rho"] - num_samples = rho_samples.shape[0] + def test_backward(self): + transform = tr.CholeskyCorr(n=3, upper=False) + x_unconstrained, x_constrained = self._get_test_values() - for i in range(num_samples): - sample_matrix = rho_samples[i] + np.testing.assert_allclose( + transform.backward(x_unconstrained).eval(), + x_constrained, + atol=1e-6, + ) - # Check if the sampled matrix is positive definite - try: - np.linalg.cholesky(sample_matrix) - except np.linalg.LinAlgError: - pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.") + def test_transform_round_trip(self): + transform = tr.CholeskyCorr(n=3, upper=False) + x_unconstrained, x_constrained = self._get_test_values() + constrained_reconstructed = transform.backward(transform.forward(x_constrained)).eval() + unconstrained_reconstructed = transform.forward(transform.backward(x_unconstrained)).eval() -def test_lkjcorr_round_trip_various_sizes(): - """ - Perform round-trip transformation tests for various sizes of correlation matrices. - """ - for n in [2, 3, 4]: - transform = CholeskyCorr(n=n) - unconstrained_size = int(n * (n - 1) / 2) + np.testing.assert_allclose(x_unconstrained, unconstrained_reconstructed, atol=1e-6) + np.testing.assert_allclose(x_constrained, constrained_reconstructed, atol=1e-6) - # Generate random unconstrained real numbers - x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX) - x_tensor = pt.as_tensor_variable(x) + def test_log_jac_det(self): + transform = tr.CholeskyCorr(n=3, upper=False) + x_unconstrained, x_constrained = self._get_test_values() - # Perform forward transform - y = transform.forward(x_tensor).eval() + computed_log_jac_det = transform.log_jac_det(x_unconstrained).eval() - # Perform backward transform - reconstructed = transform.backward(y).eval() + x = pt.tensor("x", shape=(3,)) + lower_tri_vec = transform.backward(x)[pt.tril_indices(x.shape[0], k=-1)].ravel() + jac = pt.jacobian(lower_tri_vec, x, vectorize=True) + _, autodiff_log_jac_det = pt.linalg.slogdet(jac) - # Assert that the original and reconstructed unconstrained parameters are close - assert_allclose(x, reconstructed, atol=1e-6) + np.testing.assert_allclose( + autodiff_log_jac_det.eval({x: x_unconstrained}), computed_log_jac_det, atol=1e-6 + ) From 68d2ff2ffd9cdb78a0ec7fa0b382581c0f9ccc09 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 21:28:17 -0500 Subject: [PATCH 06/10] Rename transformer to CholeskyCorrTransformer --- pymc/distributions/multivariate.py | 6 +++--- pymc/distributions/transforms.py | 17 ++++++++--------- tests/distributions/test_transform.py | 12 ++++++------ 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 80d7d8f415..814e85c671 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -74,7 +74,7 @@ to_tuple, ) from pymc.distributions.transforms import ( - CholeskyCorr, + CholeskyCorrTransform, Interval, ZeroSumTransform, _default_transform, @@ 
-1652,9 +1652,9 @@ def logp(value, n, eta): @_default_transform.register(_LKJCorr) def lkjcorr_default_transform(op, rv): - _, _, _, n, *_ = rv.owner.inputs + rng, shape, n, eta, *_ = rv.owner.inputs = rv.owner.inputs n = pt.get_scalar_constant_value(n) # Safely extract scalar value without eval - return CholeskyCorr(n) + return CholeskyCorrTransform(n=n, upper=False) class LKJCorr: diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 465b3fc46e..17d4bd6055 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -34,7 +34,7 @@ __all__ = [ "Chain", "Chain", - "CholeskyCorr", + "CholeskyCorrTransform", "CholeskyCovPacked", "CholeskyCovPacked", "Interval", @@ -140,7 +140,7 @@ def log_jac_det(self, value, *inputs): return pt.sum(y, axis=-1) -class CholeskyCorr(Transform): +class CholeskyCorrTransform(Transform): """ Map an unconstrained real vector the Cholesky factor of a correlation matrix. @@ -181,7 +181,7 @@ class CholeskyCorr(Transform): https://github.com/tensorflow/probability/ """ - name = "cholesky-corr" + name = "cholesky_corr" def __init__(self, n, upper: bool = False): """ @@ -267,15 +267,14 @@ def _fill_triangular_spiral( upper = self.upper if unit_diag: - m -= n - n -= 1 + n = n - 1 tail = x_raveled[..., n:] if upper: - xc = pt.concatenate([x_raveled, pt.flip(tail, -1)]) + xc = pt.concatenate([x_raveled, pt.flip(tail, -1)], axis=-1) else: - xc = pt.concatenate([tail, pt.flip(x_raveled, -1)]) + xc = pt.concatenate([tail, pt.flip(x_raveled, -1)], axis=-1) y = pt.reshape(xc, (*batch_shape, n, n)) return pt.triu(y) if upper else pt.tril(y) @@ -306,8 +305,8 @@ def _inverse_fill_triangular_spiral( n, m = self.n, self.m if unit_diag: - m -= n - n -= 1 + m = m - n + n = n - 1 upper = self.upper diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index f640d5f854..d01971f6d2 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -667,7 +667,7 @@ def log_jac_det(self, value, *inputs): m.logp(jacobian=jacobian_val) -class TestLJKCholeskyCorr: +class TestLJKCholeskyCorrTransform: def _get_test_values(self): x_unconstrained = np.array([2.0, 2.0, 1.0]) x_constrained = np.array( @@ -696,7 +696,7 @@ def test_fill_triangular_spiral(self, upper): ] ) - transform = tr.CholeskyCorr(n=3, upper=upper) + transform = tr.CholeskyCorrTransform(n=3, upper=upper) np.testing.assert_allclose( transform._fill_triangular_spiral(x_unconstrained, unit_diag=False).eval(), @@ -709,7 +709,7 @@ def test_fill_triangular_spiral(self, upper): ) def test_forward(self): - transform = tr.CholeskyCorr(n=3, upper=False) + transform = tr.CholeskyCorrTransform(n=3, upper=False) x_unconstrained, x_constrained = self._get_test_values() np.testing.assert_allclose( @@ -719,7 +719,7 @@ def test_forward(self): ) def test_backward(self): - transform = tr.CholeskyCorr(n=3, upper=False) + transform = tr.CholeskyCorrTransform(n=3, upper=False) x_unconstrained, x_constrained = self._get_test_values() np.testing.assert_allclose( @@ -729,7 +729,7 @@ def test_backward(self): ) def test_transform_round_trip(self): - transform = tr.CholeskyCorr(n=3, upper=False) + transform = tr.CholeskyCorrTransform(n=3, upper=False) x_unconstrained, x_constrained = self._get_test_values() constrained_reconstructed = transform.backward(transform.forward(x_constrained)).eval() @@ -739,7 +739,7 @@ def test_transform_round_trip(self): np.testing.assert_allclose(x_constrained, constrained_reconstructed, atol=1e-6) 
def test_log_jac_det(self): - transform = tr.CholeskyCorr(n=3, upper=False) + transform = tr.CholeskyCorrTransform(n=3, upper=False) x_unconstrained, x_constrained = self._get_test_values() computed_log_jac_det = transform.log_jac_det(x_unconstrained).eval() From 6f7acf5b798d39b0119f041e991bfe5a8dade6f6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 21:29:40 -0500 Subject: [PATCH 07/10] Blockwise PosDefMatrix check --- pymc/distributions/multivariate.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 814e85c671..de54965d2a 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -33,6 +33,7 @@ get_underlying_scalar_constant_value, sigmoid, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace @@ -923,10 +924,8 @@ def posdef(AA): class PosDefMatrix(Op): """Check if input is positive definite. Input should be a square matrix.""" - # Properties attribute __props__ = () - - # Compulsory if itypes and otypes are not defined + gufunc_signature = "(m,m)->()" def make_node(self, x): x = pt.as_tensor_variable(x) @@ -934,7 +933,6 @@ def make_node(self, x): o = TensorType(dtype="bool", shape=[])() return Apply(self, [x], [o]) - # Python implementation: def perform(self, node, inputs, outputs): (x,) = inputs (z,) = outputs @@ -955,7 +953,7 @@ def __str__(self): return "MatrixIsPositiveDefinite" -matrix_pos_def = PosDefMatrix() +matrix_pos_def = Blockwise(PosDefMatrix()) class WishartRV(RandomVariable): From 4a0f717b2ca86e3d955e7ddd1dccce99517a5cec Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 21:30:50 -0500 Subject: [PATCH 08/10] LKJCorr is always a matrix --- pymc/distributions/multivariate.py | 108 +++++++++-------------- tests/distributions/test_multivariate.py | 58 ++++++------ 2 files changed, 70 insertions(+), 96 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index de54965d2a..90f3e75d26 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -34,6 +34,7 @@ sigmoid, ) from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.einsum import _delta from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace @@ -76,7 +77,6 @@ ) from pymc.distributions.transforms import ( CholeskyCorrTransform, - Interval, ZeroSumTransform, _default_transform, ) @@ -1157,12 +1157,12 @@ def _lkj_normalizing_constant(eta, n): if not isinstance(n, int): raise NotImplementedError("n must be an integer") if eta == 1: - result = gammaln(2.0 * pt.arange(1, int((n - 1) / 2) + 1)).sum() + result = gammaln(2.0 * pt.arange(1, ((n - 1) / 2) + 1)).sum() if n % 2 == 1: result += ( 0.25 * (n**2 - 1) * pt.log(np.pi) - 0.25 * (n - 1) ** 2 * pt.log(2.0) - - (n - 1) * gammaln(int((n + 1) / 2)) + - (n - 1) * gammaln((n + 1) / 2) ) else: result += ( @@ -1504,7 +1504,7 @@ def helper_deterministics(cls, n, packed_chol): class LKJCorrRV(SymbolicRandomVariable): name = "lkjcorr" - extended_signature = "[rng],[size],(),()->[rng],(n)" + extended_signature = "[rng],[size],(),()->[rng],(n,n)" _print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}") def 
make_node(self, rng, size, n, eta): @@ -1532,23 +1532,13 @@ def rv_op(cls, n: int, eta, *, rng=None, size=None): flat_size = pt.prod(size, dtype="int64") next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) - - triu_idx = pt.triu_indices(n, k=1) - samples = C[..., triu_idx[0], triu_idx[1]] - - if rv_size_is_none(size): - samples = samples[0] - else: - dist_shape = (n * (n - 1)) // 2 - samples = pt.reshape(samples, (*size, dist_shape)) + C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n)) return cls( inputs=[rng, size, n, eta], - outputs=[next_rng, samples], + outputs=[next_rng, C], )(rng, size, n, eta) - return samples - @classmethod def _random_corr_matrix( cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable @@ -1565,6 +1555,7 @@ def _random_corr_matrix( P = P[..., 0, 1].set(r12) P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2)) n = get_underlying_scalar_constant_value(n) + for mp1 in range(2, n): beta -= 0.5 next_rng, y = pt.random.beta( @@ -1577,17 +1568,10 @@ def _random_corr_matrix( P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z) P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y)) C = pt.einsum("...ji,...jk->...ik", P, P.copy()) - return next_rng, C - -class MultivariateIntervalTransform(Interval): - name = "interval" - - def log_jac_det(self, *args): - return super().log_jac_det(*args).sum(-1) + return next_rng, C -# Returns list of upper triangular values class _LKJCorr(BoundedContinuous): rv_type = LKJCorrRV rv_op = LKJCorrRV.rv_op @@ -1598,10 +1582,15 @@ def dist(cls, n, eta, **kwargs): eta = pt.as_tensor_variable(eta) return super().dist([n, eta], **kwargs) - def support_point(rv, *args): - return pt.zeros_like(rv) + @staticmethod + def support_point(rv: TensorVariable, *args): + ndim = rv.ndim - def logp(value, n, eta): + # Batched identity matrix + return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int) + + @staticmethod + def logp(value: TensorVariable, n, eta): """ Calculate logp of LKJ distribution at specified value. @@ -1614,31 +1603,20 @@ def logp(value, n, eta): ------- TensorVariable """ - if value.ndim > 1: - raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)") - - # TODO: PyTensor does not have a `triu_indices`, so we can only work with constant - # n (or else find a different expression) + # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants try: n = int(get_underlying_scalar_constant_value(n)) except NotScalarConstantError: raise NotImplementedError("logp only implemented for constant `n`") - shape = n * (n - 1) // 2 - tri_index = np.zeros((n, n), dtype="int32") - tri_index[np.triu_indices(n, k=1)] = np.arange(shape) - tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape) - - value = pt.take(value, tri_index) - value = pt.fill_diagonal(value, 1) - - # TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants try: eta = float(get_underlying_scalar_constant_value(eta)) except NotScalarConstantError: raise NotImplementedError("logp only implemented for constant `eta`") + result = _lkj_normalizing_constant(eta, n) result += (eta - 1.0) * pt.log(det(value)) + return check_parameters( result, value >= -1, @@ -1675,10 +1653,6 @@ class LKJCorr: The shape parameter (eta > 0) of the LKJ distribution. eta = 1 implies a uniform distribution of the correlation matrices; larger values put more weight on matrices with few correlations. - return_matrix : bool, default=False - If True, returns the full correlation matrix. 
- False only returns the values of the upper triangular matrix excluding - diagonal in a single vector of length n(n-1)/2 for memory efficiency Notes ----- @@ -1693,7 +1667,7 @@ class LKJCorr: # Define the vector of fixed standard deviations sds = 3 * np.ones(10) - corr = pm.LKJCorr("corr", eta=4, n=10, return_matrix=True) + corr = pm.LKJCorr("corr", eta=4, n=10) # Define a new MvNormal with the given correlation matrix vals = sds * pm.MvNormal("vals", mu=np.zeros(10), cov=corr, shape=10) @@ -1703,10 +1677,6 @@ class LKJCorr: chol = pt.linalg.cholesky(corr) vals = sds * pt.dot(chol, vals_raw) - # The matrix is internally still sampled as a upper triangular vector - # If you want access to it in matrix form in the trace, add - pm.Deterministic("corr_mat", corr) - References ---------- @@ -1716,26 +1686,28 @@ class LKJCorr: 100(9), pp.1989-2001. """ - def __new__(cls, name, n, eta, *, return_matrix=False, **kwargs): - c_vec = _LKJCorr(name, eta=eta, n=n, **kwargs) - if not return_matrix: - return c_vec - else: - return cls.vec_to_corr_mat(c_vec, n) - - @classmethod - def dist(cls, n, eta, *, return_matrix=False, **kwargs): - c_vec = _LKJCorr.dist(eta=eta, n=n, **kwargs) - if not return_matrix: - return c_vec - else: - return cls.vec_to_corr_mat(c_vec, n) + def __new__(cls, name, n, eta, **kwargs): + return_matrix = kwargs.pop("return_matrix", None) + if return_matrix is not None: + warnings.warn( + "The `return_matrix` argument is deprecated and has no effect. " + "LKJCorr always returns the correlation matrix.", + DeprecationWarning, + stacklevel=2, + ) + return _LKJCorr(name, eta=eta, n=n, **kwargs) @classmethod - def vec_to_corr_mat(cls, vec, n): - tri = pt.zeros(pt.concatenate([vec.shape[:-1], (n, n)])) - tri = pt.subtensor.set_subtensor(tri[(..., *np.triu_indices(n, 1))], vec) - return tri + pt.moveaxis(tri, -2, -1) + pt.diag(pt.ones(n)) + def dist(cls, n, eta, **kwargs): + return_matrix = kwargs.pop("return_matrix", None) + if return_matrix is not None: + warnings.warn( + "The `return_matrix` argument is deprecated and has no effect. 
" + "LKJCorr always returns the correlation matrix.", + DeprecationWarning, + stacklevel=2, + ) + return _LKJCorr.dist(eta=eta, n=n, **kwargs) class MatrixNormalRV(RandomVariable): diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 2597605bd1..824aae1ae5 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -35,7 +35,6 @@ from pymc import Model from pymc.distributions.multivariate import ( - MultivariateIntervalTransform, _LKJCholeskyCov, _LKJCorr, _OrderedMultinomial, @@ -43,6 +42,7 @@ quaddist_matrix, ) from pymc.distributions.shape_utils import change_dist_size, to_tuple +from pymc.distributions.transforms import CholeskyCorrTransform from pymc.logprob.basic import logp from pymc.logprob.utils import ParameterValueError from pymc.math import kronecker @@ -559,10 +559,14 @@ def test_wishart(self, n): lambda value, nu, V: st.wishart.logpdf(value, int(nu), V), ) - @pytest.mark.parametrize("x,eta,n,lp", LKJ_CASES) - def test_lkjcorr(self, x, eta, n, lp): + @pytest.mark.parametrize("x_tri,eta,n,lp", LKJ_CASES) + def test_lkjcorr(self, x_tri, eta, n, lp): with pm.Model() as model: - pm.LKJCorr("lkj", eta=eta, n=n, default_transform=None, return_matrix=False) + pm.LKJCorr("lkj", eta=eta, n=n, transform=None) + + x = np.eye(n) + x[np.tril_indices(n, -1)] = x_tri + x[np.triu_indices(n, 1)] = x_tri point = {"lkj": x} decimals = select_by_precision(float64=6, float32=4) @@ -2153,13 +2157,13 @@ class TestLKJCorr(BaseTestDistributionRandom): sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)] sizes_expected = [ - (3,), - (3,), - (1, 3), - (1, 3), - (5, 3), - (4, 5, 3), - (2, 4, 2, 3), + (3, 3), + (3, 3), + (1, 3, 3), + (1, 3, 3), + (5, 3, 3), + (4, 5, 3, 3), + (2, 4, 2, 3, 3), ] checks_to_run = [ @@ -2169,16 +2173,26 @@ class TestLKJCorr(BaseTestDistributionRandom): ] def check_draws_match_expected(self): + from pymc.distributions import CustomDist + def ref_rand(size, n, eta): shape = int(n * (n - 1) // 2) beta = eta - 1 + n / 2 - return (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 + tril_values = (st.beta.rvs(size=(size, shape), a=beta, b=beta) - 0.5) * 2 + return tril_values # If passed as a domain, continuous_random_tester would make `n` a shared variable # But this RV needs it to be constant in order to define the inner graph + def lkj_corr_tril(n, eta, shape=None): + tril_idx = pt.tril_indices(n) + return _LKJCorr.dist(n=n, eta=eta, shape=shape)[..., tril_idx[0], tril_idx[1]] + + def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs): + return CustomDist(name, n, eta, dist=lkj_corr_tril, shape=shape) + for n in (2, 10, 50): continuous_random_tester( - _LKJCorr, + SlicedLKJ, { "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)), }, @@ -2188,24 +2202,12 @@ def ref_rand(size, n, eta): ) -@pytest.mark.parametrize( - argnames="shape", - argvalues=[ - (2,), - pytest.param( - (3, 2), - marks=pytest.mark.xfail( - raises=NotImplementedError, - reason="LKJCorr logp is only implemented for vector values (ndim=1)", - ), - ), - ], -) +@pytest.mark.parametrize("shape", [(2, 2), (3, 2, 2)], ids=["no_batch", "with_batch"]) def test_LKJCorr_default_transform(shape): with pm.Model() as m: x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False) - assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform) - assert m.logp(sum=False)[0].type.shape == shape[:-1] + assert isinstance(m.rvs_to_transforms[x], CholeskyCorrTransform) + assert 
m.logp(sum=False)[0].type.shape == shape[:-2] class TestLKJCholeskyCov(BaseTestDistributionRandom): From 27fd53daedbc82128d7ddd5258a23c65d9ab653e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 4 Oct 2025 23:54:24 -0500 Subject: [PATCH 09/10] Don't flatten batch dims --- pymc/distributions/multivariate.py | 53 ++++++++++-------------- tests/distributions/test_multivariate.py | 21 +++++----- 2 files changed, 32 insertions(+), 42 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 90f3e75d26..39278fcc57 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -34,7 +34,6 @@ sigmoid, ) from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.einsum import _delta from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace @@ -1213,12 +1212,8 @@ def rv_op(cls, n, eta, sd_dist, *, size=None): D = sd_dist.type(name="D") # Make sd_dist opaque to OpFromGraph size = D.shape[:-1] - # We flatten the size to make operations easier, and then rebuild it - flat_size = pt.prod(size, dtype="int64") - - next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) - D_matrix = D.reshape((flat_size, n)) - C *= D_matrix[..., :, None] * D_matrix[..., None, :] + next_rng, C = LKJCorrRV._random_corr_matrix(rng=rng, n=n, eta=eta, size=size) + C *= D[..., :, None] * D[..., None, :] tril_idx = pt.tril_indices(n, k=0) samples = pt.linalg.cholesky(C)[..., tril_idx[0], tril_idx[1]] @@ -1520,53 +1515,52 @@ def make_node(self, rng, size, n, eta): @classmethod def rv_op(cls, n: int, eta, *, rng=None, size=None): - # We flatten the size to make operations easier, and then rebuild it + # HACK: normalize_size_param doesn't handle size=() properly + if not size: + size = None + n = pt.as_tensor(n, ndim=0, dtype=int) eta = pt.as_tensor(eta, ndim=0) rng = normalize_rng_param(rng) size = normalize_size_param(size) - if rv_size_is_none(size): - flat_size = 1 - else: - flat_size = pt.prod(size, dtype="int64") + next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, size=size) - next_rng, C = cls._random_corr_matrix(rng=rng, n=n, eta=eta, flat_size=flat_size) - C = C[0] if rv_size_is_none(size) else C.reshape((*size, n, n)) - - return cls( - inputs=[rng, size, n, eta], - outputs=[next_rng, C], - )(rng, size, n, eta) + return cls(inputs=[rng, size, n, eta], outputs=[next_rng, C])(rng, size, n, eta) @classmethod def _random_corr_matrix( - cls, rng: Variable, n: int, eta: TensorVariable, flat_size: TensorVariable + cls, rng: Variable, n: int, eta: TensorVariable, size: TensorVariable ) -> tuple[Variable, TensorVariable]: # original implementation in R see: # https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r + size = () if rv_size_is_none(size) else size beta = eta - 1.0 + n / 2.0 - next_rng, beta_rvs = pt.random.beta( - alpha=beta, beta=beta, size=flat_size, rng=rng - ).owner.outputs + next_rng, beta_rvs = pt.random.beta(alpha=beta, beta=beta, size=size, rng=rng).owner.outputs r12 = 2.0 * beta_rvs - 1.0 - P = pt.full((flat_size, n, n), pt.eye(n)) + + P = pt.full((*size, n, n), pt.eye(n)) P = P[..., 0, 1].set(r12) P = P[..., 1, 1].set(pt.sqrt(1.0 - r12**2)) n = get_underlying_scalar_constant_value(n) for mp1 in range(2, n): beta -= 0.5 + next_rng, y = pt.random.beta( - alpha=mp1 / 2.0, beta=beta, size=flat_size, rng=next_rng + alpha=mp1 / 2.0, 
beta=beta, size=size, rng=next_rng ).owner.outputs + next_rng, z = pt.random.normal( - loc=0, scale=1, size=(flat_size, mp1), rng=next_rng + loc=0, scale=1, size=(*size, mp1), rng=next_rng ).owner.outputs - z = z / pt.sqrt(pt.einsum("ij,ij->i", z, z.copy()))[..., np.newaxis] + + ein_sig_z = "i, i->" if z.ndim == 1 else "...ij, ...ij->...i" + z = z / pt.sqrt(pt.einsum(ein_sig_z, z, z.copy()))[..., np.newaxis] P = P[..., 0:mp1, mp1].set(pt.sqrt(y[..., np.newaxis]) * z) P = P[..., mp1, mp1].set(pt.sqrt(1.0 - y)) + C = pt.einsum("...ji,...jk->...ik", P, P.copy()) return next_rng, C @@ -1584,10 +1578,7 @@ def dist(cls, n, eta, **kwargs): @staticmethod def support_point(rv: TensorVariable, *args): - ndim = rv.ndim - - # Batched identity matrix - return _delta(rv.shape, (ndim - 2, ndim - 1)).astype(int) + return pt.broadcast_to(pt.eye(rv.shape[-1]), rv.shape) @staticmethod def logp(value: TensorVariable, n, eta): diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 824aae1ae5..00c6b73114 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1312,17 +1312,17 @@ def test_kronecker_normal_support_point(self, mu, covs, size, expected): @pytest.mark.parametrize( "n, eta, size, expected", [ - (3, 1, None, np.zeros(3)), - (5, 1, None, np.zeros(10)), - pytest.param(3, 1, 1, np.zeros((1, 3))), - pytest.param(5, 1, (2, 3), np.zeros((2, 3, 10))), + (3, 1, None, np.eye(3)), + (5, 1, None, np.eye(5)), + (3, 1, (1,), np.broadcast_to(np.eye(3), (1, 3, 3))), + (5, 1, (2, 3), np.broadcast_to(np.eye(5), (2, 3, 5, 5))), ], + ids=["n=3", "n=5", "batch_1", "batch_2"], ) def test_lkjcorr_support_point(self, n, eta, size, expected): with pm.Model() as model: - pm.LKJCorr("x", n=n, eta=eta, size=size, return_matrix=False) - # LKJCorr logp is only implemented for vector values (size=None) - assert_support_point_is_expected(model, expected, check_finite_logp=size is None) + pm.LKJCorr("x", n=n, eta=eta, size=size) + assert_support_point_is_expected(model, expected, check_finite_logp=True) @pytest.mark.parametrize( "n, eta, size, expected", @@ -1466,13 +1466,12 @@ def test_with_lkjcorr_matrix( self, ): with pm.Model() as model: - corr = pm.LKJCorr("corr", n=3, eta=2, return_matrix=True) - pm.Deterministic("corr_mat", corr) - mv = pm.MvNormal("mv", 0.0, cov=corr, size=4) + corr_mat = pm.LKJCorr("corr_mat", n=3, eta=2) + mv = pm.MvNormal("mv", 0.0, cov=corr_mat, size=4) prior = pm.sample_prior_predictive(draws=10, return_inferencedata=False) assert prior["corr_mat"].shape == (10, 3, 3) # square - assert (prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]] == 1.0).all() # 1.0 on diagonal + assert np.allclose(prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0) # 1.0 on diagonal assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all() # symmetric assert ( prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0 From 1f29284921c37db4d843b94f3c97ce8d8eec1547 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 5 Oct 2025 00:20:26 -0500 Subject: [PATCH 10/10] Test fixes --- tests/distributions/test_multivariate.py | 25 ++++++++++-------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index 00c6b73114..5c1a3b112f 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1471,11 +1471,15 @@ def test_with_lkjcorr_matrix( prior = pm.sample_prior_predictive(draws=10, 
return_inferencedata=False) assert prior["corr_mat"].shape == (10, 3, 3) # square - assert np.allclose(prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0) # 1.0 on diagonal assert (prior["corr_mat"] == prior["corr_mat"].transpose(0, 2, 1)).all() # symmetric - assert ( - prior["corr_mat"].max() <= 1.0 and prior["corr_mat"].min() >= -1.0 - ) # constrained between -1 and 1 + + np.testing.assert_allclose( + prior["corr_mat"][:, [0, 1, 2], [0, 1, 2]], 1.0 + ) # 1.0 on diagonal + + # constrained between -1 and 1 + assert prior["corr_mat"].max() <= (1.0 + 1e-12) + assert prior["corr_mat"].min() >= (-1.0 - 1e-12) def test_issue_3758(self): np.random.seed(42) @@ -2172,8 +2176,6 @@ class TestLKJCorr(BaseTestDistributionRandom): ] def check_draws_match_expected(self): - from pymc.distributions import CustomDist - def ref_rand(size, n, eta): shape = int(n * (n - 1) // 2) beta = eta - 1 + n / 2 @@ -2182,16 +2184,9 @@ def ref_rand(size, n, eta): # If passed as a domain, continuous_random_tester would make `n` a shared variable # But this RV needs it to be constant in order to define the inner graph - def lkj_corr_tril(n, eta, shape=None): - tril_idx = pt.tril_indices(n) - return _LKJCorr.dist(n=n, eta=eta, shape=shape)[..., tril_idx[0], tril_idx[1]] - - def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs): - return CustomDist(name, n, eta, dist=lkj_corr_tril, shape=shape) - for n in (2, 10, 50): continuous_random_tester( - SlicedLKJ, + _LKJCorr, { "eta": Domain([1.0, 10.0, 100.0], edges=(None, None)), }, @@ -2204,7 +2199,7 @@ def SlicedLKJ(name, n, eta, *args, shape=None, **kwargs): @pytest.mark.parametrize("shape", [(2, 2), (3, 2, 2)], ids=["no_batch", "with_batch"]) def test_LKJCorr_default_transform(shape): with pm.Model() as m: - x = pm.LKJCorr("x", n=2, eta=1, shape=shape, return_matrix=False) + x = pm.LKJCorr("x", n=2, eta=1, shape=shape) assert isinstance(m.rvs_to_transforms[x], CholeskyCorrTransform) assert m.logp(sum=False)[0].type.shape == shape[:-2]
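Taken together, the series leaves `pm.LKJCorr` drawing full correlation
matrices, with `CholeskyCorrTransform` applied by default. A minimal usage
sketch, assuming the patches apply cleanly and mirroring the prior-predictive
checks in the tests above:

    import numpy as np
    import pymc as pm

    with pm.Model():
        # The draw is now the full (3, 3) correlation matrix; the old
        # `return_matrix` flag is deprecated and ignored.
        pm.LKJCorr("corr", n=3, eta=2.0)
        prior = pm.sample_prior_predictive(draws=100, return_inferencedata=False)

    samples = prior["corr"]                                    # (100, 3, 3)
    assert np.allclose(samples, samples.transpose(0, 2, 1))    # symmetric
    assert np.allclose(samples[:, [0, 1, 2], [0, 1, 2]], 1.0)  # unit diagonal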