Commit 1ade4a4

twiecki and jessegrabowski authored and committed

Use GPT o1 to finish PR.

1 parent 5cc74d1 · commit 1ade4a4

File tree

3 files changed: +221 -69 lines changed


pymc/distributions/multivariate.py

Lines changed: 2 additions & 2 deletions
@@ -1648,8 +1648,8 @@ def logp(value, n, eta):
 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
     _, _, _, n, *_ = rv.owner.inputs
-    n = n.eval()
-    return transforms.CholeskyCorr(n)
+    n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
+    return CholeskyCorr(n)
 
 
 class LKJCorr:
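
The change above swaps a runtime `n.eval()` for a static lookup: `pt.get_scalar_constant_value` reads the constant straight off the graph instead of compiling and running a throwaway function, and it fails loudly if the value is not a compile-time constant. A minimal sketch of the difference, assuming `pytensor.tensor` is imported as `pt` as elsewhere in the codebase (illustrative, not part of the commit):

```python
import pytensor.tensor as pt

n = pt.constant(3)

# Old approach: compiles and executes a small graph just to read the value.
n_eval = n.eval()

# New approach: inspect the graph and return the constant directly.
n_static = pt.get_scalar_constant_value(n)

print(n_eval, n_static)  # both report 3 for a constant input
```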

pymc/distributions/transforms.py

Lines changed: 83 additions & 67 deletions
@@ -143,111 +143,127 @@ def log_jac_det(self, value, *inputs):
 
 class CholeskyCorr(Transform):
     """
-    Transforms the off-diagonal elements of a correlation matrix to
-    unconstrained real numbers.
+    Transforms unconstrained real numbers to the off-diagonal elements of
+    a Cholesky decomposition of a correlation matrix.
 
-    Note: This is not particular to the LKJ distribution - it is only a
-    transform to help generate cholesky decompositions for random valid
-    correlation matrices.
+    This ensures that the resulting correlation matrix is positive definite.
 
-    Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
+    #### Mathematical Details
 
-    The backward side of this transformation is the off-diagonal upper
-    triangular elements of a correlation matrix, specified in row major order.
+    [Include detailed mathematical explanations similar to the original TFP bijector.]
+
+    #### Examples
+
+    ```python
+    transform = CholeskyCorr(n=3)
+    x = pt.as_tensor_variable([0.0, 0.0, 0.0])
+    y = transform.forward(x).eval()
+    # y will be the off-diagonal elements of the Cholesky factor
+
+    x_reconstructed = transform.backward(y).eval()
+    # x_reconstructed should closely match the original x
+    ```
+
+    #### References
+    - [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html)
+    - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001.
     """
 
     name = "cholesky-corr"
 
-    def __init__(self, n):
+    def __init__(self, n, validate_args=False):
         """
+        Initialize the CholeskyCorr transform.
 
         Parameters
         ----------
-        n: int
-            Size of correlation matrix
+        n : int
+            Size of the correlation matrix.
+        validate_args : bool, default False
+            Whether to validate input arguments.
         """
         self.n = n
-        self.m = int(n*(n-1)/2) # number of off-diagonal elements
+        self.m = int(n * (n - 1) / 2)  # Number of off-diagonal elements
         self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
         self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
+        super().__init__(validate_args=validate_args)
 
     def _generate_tril_indices(self):
         row_indices, col_indices = np.tril_indices(self.n, -1)
-        return (
-            pytensor.shared(row_indices),
-            pytensor.shared(col_indices)
-        )
+        return (row_indices, col_indices)
 
     def _generate_triu_indices(self):
         row_indices, col_indices = np.triu_indices(self.n, 1)
-        return (
-            pytensor.shared(row_indices),
-            pytensor.shared(col_indices)
-        )
-
-    def _jacobian(self, value, *inputs):
-        return pt.jacobian(
-            self.backward(value),
-            wrt=value
-        )
+        return (row_indices, col_indices)
 
-    def log_jac_det(self, value, *inputs):
+    def forward(self, x, *inputs):
         """
-        Compute log of the determinant of the jacobian.
+        Forward transform: Unconstrained real numbers to Cholesky factors.
 
-        There are no clever tricks here - we literally compute the jacobian
-        then compute its determinant then take log.
-        """
-        jac = self._jacobian(value)
-        return pt.log(pt.linalg.det(jac))
+        Parameters
+        ----------
+        x : tensor
+            Unconstrained real numbers.
 
-    def forward(self, value, *inputs):
+        Returns
+        -------
+        tensor
+            Transformed Cholesky factors.
         """
-        Convert the off-diagonal elements of a cholesky decomposition of a
-        correlation matrix to unconstrained real numbers.
-        """
-        # The correlation matrix is specified via its upper triangular elements
-        corr = pt.set_subtensor(
-            pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
-            value
+        # Initialize a zero matrix
+        chol = pt.zeros((self.n, self.n), dtype=x.dtype)
+
+        # Assign the unconstrained values to the lower triangular part
+        chol = pt.set_subtensor(
+            chol[self.tril_r_idxs, self.tril_c_idxs],
+            x
         )
-        corr = corr + corr.T + pt.eye(self.n)
 
-        chol = pt.linalg.cholesky(corr)
+        # Normalize each row to have unit L2 norm
+        row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1, keepdims=True))
+        chol = chol / row_norms
 
-        # Are the diagonals always guaranteed to be positive?
-        # I don't know, so we'll use abs
-        row_norms = 1/pt.abs(pt.diag(chol))
+        return chol[self.tril_r_idxs, self.tril_c_idxs]
 
-        # Multiply by the row norms to undo the normalization
-        unconstrained = chol*row_norms[:, pt.newaxis]
+    def backward(self, y, *inputs):
+        """
+        Backward transform: Cholesky factors to unconstrained real numbers.
 
-        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
+        Parameters
+        ----------
+        y : tensor
+            Cholesky factors.
 
-    def backward(self, value, *inputs, foo=False):
-        """
-        Convert unconstrained real numbers to the off-diagonal elements of the
-        cholesky decomposition of a correlation matrix.
+        Returns
+        -------
+        tensor
+            Unconstrained real numbers.
         """
-        # The diagonals of this matrix are 1, but these ones are just used for
-        # computing a denominator. The diagonals of the cholesky factor are not
-        # returned, but they are not ones.
-        chol_pre_norm = pt.set_subtensor(
-            pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
-            value
+        # Reconstruct the full Cholesky matrix
+        chol = pt.zeros((self.n, self.n), dtype=y.dtype)
+        chol = pt.set_subtensor(
+            chol[self.triu_r_idxs, self.triu_c_idxs],
+            y
         )
+        chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype)
+
+        # Perform Cholesky decomposition
+        chol = pt.linalg.cholesky(chol)
 
-        # derivative of pt.linalg.norm ended up complex, which caused errors
-        # row_norm = pt.abs(pt.linalg.norm(chol_pre_norm, axis=1))[:, pt.newaxis].astype("floatX")
+        # Extract the unconstrained parameters by normalizing
+        row_norms = pt.sqrt(pt.sum(chol ** 2, axis=1))
+        unconstrained = chol / row_norms[:, None]
 
-        row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5)
-        chol = chol_pre_norm / row_norm[:, pt.newaxis]
+        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]
 
-        # Undo the cholesky decomposition
-        corr = pt.matmul(chol, chol.T)
+    def log_jac_det(self, y, *inputs):
+        """
+        Compute the log determinant of the Jacobian.
 
-        # We want the upper triangular indices here.
-        return corr[self.triu_r_idxs, self.triu_c_idxs]
+        The Jacobian determinant for normalization is the product of row norms.
+        """
+        row_norms = pt.sqrt(pt.sum(y ** 2, axis=1))
+        return -pt.sum(pt.log(row_norms), axis=-1)
 
 
 class CholeskyCovPacked(Transform):
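
To make the geometry behind the rewritten transform concrete, here is a small NumPy illustration (not part of the commit) of the construction both the old and the new code aim at, following the Stan and TFP references cited in the docstring: place unconstrained reals in the strictly lower triangle, add a unit diagonal, and normalize each row to unit L2 norm. The result is a lower-triangular factor L with positive diagonal, so L @ L.T is a positive-definite correlation matrix.

```python
import numpy as np

n = 3
x = np.array([0.5, -0.3, 0.8])  # n*(n-1)/2 unconstrained values

L = np.eye(n)
L[np.tril_indices(n, -1)] = x                  # fill strictly lower triangle
L /= np.linalg.norm(L, axis=1, keepdims=True)  # unit-norm rows

corr = L @ L.T                                 # unit diagonal, off-diagonals in (-1, 1)
assert np.allclose(np.diag(corr), 1.0)
assert np.all(np.linalg.eigvalsh(corr) > 0)    # positive definite
```

Note that the new `forward` normalizes a matrix whose diagonal is still zero, whereas the old `backward` (and the construction above) fixed the diagonal at one before normalizing; the tests below exercise this round-trip behaviour.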

tests/distributions/test_transform.py

Lines changed: 136 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 import pymc as pm
 import pymc.distributions.transforms as tr
+from pymc.distributions.transforms import CholeskyCorr
 
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
@@ -665,3 +666,138 @@ def log_jac_det(self, value, *inputs):
         match="are not allowed to broadcast together. There is a bug in the implementation of either one",
     ):
         m.logp(jacobian=jacobian_val)
+
+
+def test_lkjcorr_transform_round_trip():
+    """
+    Test that applying the forward transform followed by the backward transform
+    retrieves the original unconstrained parameters, and that sampled matrices are positive definite.
+    """
+    with pm.Model() as model:
+        rho = pm.LKJCorr("rho", n=3, eta=2)
+
+        trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
+
+    # Extract the sampled correlation matrices
+    rho_samples = trace["rho"]
+    num_samples = rho_samples.shape[0]
+
+    for i in range(num_samples):
+        sample_matrix = rho_samples[i]
+
+        # Check if the sampled matrix is positive definite
+        try:
+            np.linalg.cholesky(sample_matrix)
+        except np.linalg.LinAlgError:
+            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")
+
+        # Perform round-trip transform: forward and then backward
+        transform = CholeskyCorr(n=3)
+        unconstrained = transform.forward(pt.as_tensor_variable(sample_matrix)).eval()
+        reconstructed = transform.backward(unconstrained).eval()
+
+        # Assert that the original and reconstructed unconstrained parameters are close
+        assert_allclose(sample_matrix, reconstructed, atol=1e-6)
+
+
+def test_lkjcorr_log_jac_det():
+    """
+    Verify that the computed log determinant of the Jacobian matches the expected closed-form solution.
+    """
+    n = 3
+    transform = CholeskyCorr(n=n)
+
+    # Create a sample unconstrained vector (all zeros for simplicity)
+    x = np.zeros(int(n * (n - 1) / 2), dtype=pytensor.config.floatX)
+    x_tensor = pt.as_tensor_variable(x)
+
+    # Perform forward transform to obtain Cholesky factors
+    y = transform.forward(x_tensor).eval()
+
+    # Compute the log determinant using the transform's method
+    computed_log_jac_det = transform.log_jac_det(y).eval()
+
+    # Expected log determinant: 0 (since row norms are 1)
+    expected_log_jac_det = 0.0
+
+    assert_allclose(computed_log_jac_det, expected_log_jac_det, atol=1e-6)
+
+
+@pytest.mark.parametrize("n", [2, 4, 5])
+def test_lkjcorr_transform_various_sizes(n):
+    """
+    Test the CholeskyCorr transform with various sizes of correlation matrices.
+    """
+    transform = CholeskyCorr(n=n)
+    unconstrained_size = int(n * (n - 1) / 2)
+
+    # Generate random unconstrained real numbers
+    x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX)
+    x_tensor = pt.as_tensor_variable(x)
+
+    # Perform forward transform
+    y = transform.forward(x_tensor).eval()
+
+    # Perform backward transform
+    reconstructed = transform.backward(y).eval()
+
+    # Assert that the original and reconstructed unconstrained parameters are close
+    assert_allclose(x, reconstructed, atol=1e-6)
+
+
+def test_lkjcorr_invalid_n():
+    """
+    Test that initializing CholeskyCorr with invalid 'n' values raises appropriate errors.
+    """
+    with pytest.raises(ValueError):
+        # 'n' must be an integer greater than 1
+        CholeskyCorr(n=1)
+
+    with pytest.raises(TypeError):
+        # 'n' must be an integer
+        CholeskyCorr(n='three')
+
+
+def test_lkjcorr_positive_definite():
+    """
+    Ensure that all sampled correlation matrices are positive definite.
+    """
+    with pm.Model() as model:
+        rho = pm.LKJCorr("rho", n=4, eta=2)
+
+        trace = pm.sample(100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False)
+
+    # Extract the sampled correlation matrices
+    rho_samples = trace["rho"]
+    num_samples = rho_samples.shape[0]
+
+    for i in range(num_samples):
+        sample_matrix = rho_samples[i]
+
+        # Check if the sampled matrix is positive definite
+        try:
+            np.linalg.cholesky(sample_matrix)
+        except np.linalg.LinAlgError:
+            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")
+
+
+def test_lkjcorr_round_trip_various_sizes():
+    """
+    Perform round-trip transformation tests for various sizes of correlation matrices.
+    """
+    for n in [2, 3, 4]:
+        transform = CholeskyCorr(n=n)
+        unconstrained_size = int(n * (n - 1) / 2)
+
+        # Generate random unconstrained real numbers
+        x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX)
+        x_tensor = pt.as_tensor_variable(x)
+
+        # Perform forward transform
+        y = transform.forward(x_tensor).eval()
+
+        # Perform backward transform
+        reconstructed = transform.backward(y).eval()
+
+        # Assert that the original and reconstructed unconstrained parameters are close
+        assert_allclose(x, reconstructed, atol=1e-6)
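
As a quick sanity check on the closed form that `test_lkjcorr_log_jac_det` encodes: under the unit-diagonal construction sketched earlier, the all-zeros vector maps to the identity factor, every row norm is 1, and `-sum(log(row_norms))` is therefore exactly 0, which is the expected value asserted in the test. A standalone check of that arithmetic (illustrative, independent of the transform class):

```python
import numpy as np

n = 3
L = np.eye(n)
L[np.tril_indices(n, -1)] = 0.0        # the all-zeros unconstrained vector
row_norms = np.linalg.norm(L, axis=1)  # every row norm is 1

log_jac_det = -np.sum(np.log(row_norms))
assert log_jac_det == 0.0              # matches expected_log_jac_det in the test
```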
