Skip to content

Commit 79fa542

Browse files
committed
Port TF bijector to ensure posdef LKJCorr samples
1 parent f44071b commit 79fa542

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1579,7 +1579,9 @@ def logp(value, n, eta):
15791579

15801580
@_default_transform.register(_LKJCorr)
def lkjcorr_default_transform(op, rv):
    """Return the default transform used for ``_LKJCorr`` random variables.

    The transform is a ``CholeskyCorr`` sized from the ``n`` input of the
    random variable, replacing the earlier element-wise interval transform
    on (-1, 1).

    Parameters
    ----------
    op
        The random variable Op being dispatched on (unused).
    rv
        The random variable whose owner inputs carry ``n``.
    """
    # NOTE(review): assumes the fourth owner input is ``n`` (the matrix
    # size) - confirm against the input layout of the _LKJCorr Op.
    _, _, _, n, *_ = rv.owner.inputs
    # ``eval()`` returns a 0-d NumPy array; convert it to a plain Python int
    # so the transform can use it for index generation and shape arithmetic.
    # NOTE(review): this requires ``n`` to be computable without any
    # user-provided inputs at model-construction time.
    n = int(n.eval())
    return transforms.CholeskyCorr(n)
15831585

15841586

15851587
class LKJCorr:

pymc/distributions/transforms.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import pytensor.tensor as pt
20+
import pytensor
2021

2122

2223
# ignore mypy error because it somehow considers that
@@ -45,6 +46,7 @@
4546
"log",
4647
"sum_to_1",
4748
"circular",
49+
"CholeskyCorr",
4850
"CholeskyCovPacked",
4951
"Chain",
5052
"ZeroSumTransform",
@@ -138,6 +140,115 @@ def log_jac_det(self, value, *inputs):
138140
return pt.sum(y, axis=-1)
139141

140142

143+
class CholeskyCorr(Transform):
    """
    Transforms the off-diagonal elements of a correlation matrix to
    unconstrained real numbers.

    Note: This is not particular to the LKJ distribution - it is only a
    transform to help generate cholesky decompositions for random valid
    correlation matrices.

    Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31

    The backward side of this transformation is the off-diagonal upper
    triangular elements of a correlation matrix, specified in row major order.
    The forward (unconstrained) side is the strictly lower triangular part of
    the un-normalized Cholesky factor, in the order produced by
    ``np.tril_indices(n, -1)``.
    """

    name = "cholesky-corr"

    def __init__(self, n):
        """

        Parameters
        ----------
        n: int
            Size of correlation matrix
        """
        # Accept NumPy integers / 0-d arrays (e.g. the result of ``.eval()``)
        # and normalise them to a plain Python int.
        n = int(n)
        self.n = n
        # Number of strictly off-diagonal elements in one triangle. Integer
        # arithmetic avoids the float round trip of ``int(n * (n - 1) / 2)``.
        self.m = n * (n - 1) // 2
        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

    def _generate_tril_indices(self):
        """Return shared row/column indices of the strictly lower triangle."""
        row_indices, col_indices = np.tril_indices(self.n, -1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _generate_triu_indices(self):
        """Return shared row/column indices of the strictly upper triangle."""
        row_indices, col_indices = np.triu_indices(self.n, 1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _jacobian(self, value, *inputs):
        """Return the Jacobian of ``backward`` with respect to ``value``."""
        return pt.jacobian(
            self.backward(value, *inputs),
            wrt=value,
        )

    def log_jac_det(self, value, *inputs):
        """
        Compute log of the absolute determinant of the jacobian.

        There are no clever tricks here - we literally compute the jacobian
        of ``backward``, then its determinant, then the log of its absolute
        value.
        """
        jac = self._jacobian(value, *inputs)
        # A change of variables needs log|det J|. The sign of det J depends
        # on how the input and output elements are ordered, so it is not
        # guaranteed to be positive; without the abs a negative determinant
        # would turn the log-probability into NaN.
        return pt.log(pt.abs(pt.linalg.det(jac)))

    def forward(self, value, *inputs):
        """
        Convert the off-diagonal upper triangular elements of a correlation
        matrix (row major order) to unconstrained real numbers.
        """
        # The correlation matrix is specified via its upper triangular elements
        corr = pt.set_subtensor(
            pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
            value,
        )
        # Mirror the upper triangle and put ones on the diagonal.
        corr = corr + corr.T + pt.eye(self.n)

        chol = pt.linalg.cholesky(corr)

        # ``backward`` builds the Cholesky factor by dividing each row of a
        # unit-diagonal lower triangular matrix by its norm, so each diagonal
        # entry of ``chol`` equals 1 / row_norm. Dividing every row by its
        # diagonal entry therefore undoes that normalization. The diagonal of
        # the Cholesky factor of a positive-definite matrix is strictly
        # positive, so no absolute value is needed.
        unconstrained = chol / pt.diag(chol)[:, pt.newaxis]

        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

    def backward(self, value, *inputs, foo=False):
        """
        Convert unconstrained real numbers to the off-diagonal upper
        triangular elements of a correlation matrix (row major order).

        ``foo`` is unused; it is retained only so existing callers that pass
        it keep working.
        """
        # The diagonals of this matrix are 1, but these ones are just used for
        # computing a denominator. The diagonals of the cholesky factor are not
        # returned, but they are not ones.
        chol_pre_norm = pt.set_subtensor(
            pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
            value,
        )

        # The row norm is written out explicitly instead of calling
        # pt.linalg.norm, whose derivative ended up complex and caused errors.
        # Each row contains a 1 on the diagonal, so the sum of squares is
        # always >= 1 and the square root is well defined.
        row_norm = pt.sqrt(pt.sum(pt.sqr(chol_pre_norm), axis=1))
        chol = chol_pre_norm / row_norm[:, pt.newaxis]

        # Undo the cholesky decomposition
        corr = pt.matmul(chol, chol.T)

        # We want the upper triangular indices here.
        return corr[self.triu_r_idxs, self.triu_c_idxs]
251+
141252
class CholeskyCovPacked(Transform):
142253
"""
143254
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the

0 commit comments

Comments
 (0)