|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytensor.tensor as pt
|
| 20 | +import pytensor |
20 | 21 |
|
21 | 22 |
|
22 | 23 | # ignore mypy error because it somehow considers that
|
|
45 | 46 | "log",
|
46 | 47 | "sum_to_1",
|
47 | 48 | "circular",
|
| 49 | + "CholeskyCorr", |
48 | 50 | "CholeskyCovPacked",
|
49 | 51 | "Chain",
|
50 | 52 | "ZeroSumTransform",
|
@@ -138,6 +140,115 @@ def log_jac_det(self, value, *inputs):
|
138 | 140 | return pt.sum(y, axis=-1)
|
139 | 141 |
|
140 | 142 |
|
class CholeskyCorr(Transform):
    """
    Map the off-diagonal elements of a correlation matrix to unconstrained reals.

    The constrained side (input of ``forward``, output of ``backward``) is the
    strictly upper triangular part of an ``n x n`` correlation matrix in
    row-major order. The unconstrained side is a vector of the same length,
    ``n * (n - 1) / 2``, holding the strictly lower triangular entries of the
    Cholesky factor after each row has been divided by its diagonal entry.

    Note: This is not particular to the LKJ distribution - it is only a
    transform to help generate cholesky decompositions for random valid
    correlation matrices.

    Ported from here: https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
    """

    name = "cholesky-corr"

    def __init__(self, n):
        """
        Parameters
        ----------
        n: int
            Size of correlation matrix
        """
        self.n = n
        # Number of off-diagonal elements in one triangle. Integer floor
        # division keeps the count exact instead of going through a float.
        self.m = n * (n - 1) // 2
        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()

    def _generate_tril_indices(self):
        """Return shared (row, column) indices of the strictly lower triangle."""
        row_indices, col_indices = np.tril_indices(self.n, -1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _generate_triu_indices(self):
        """Return shared (row, column) indices of the strictly upper triangle."""
        row_indices, col_indices = np.triu_indices(self.n, 1)
        return (
            pytensor.shared(row_indices),
            pytensor.shared(col_indices),
        )

    def _jacobian(self, value, *inputs):
        """Return the jacobian of ``backward(value)`` with respect to ``value``."""
        return pt.jacobian(self.backward(value, *inputs), wrt=value)

    def log_jac_det(self, value, *inputs):
        """
        Compute the log of the absolute determinant of the jacobian of
        ``backward``.

        There are no clever tricks here - we literally compute the jacobian,
        then its determinant, then take the log of its absolute value.
        """
        jac = self._jacobian(value, *inputs)
        # The determinant may be negative; a change of variables needs
        # log|det J|, and log of a negative number would be NaN.
        return pt.log(pt.abs(pt.linalg.det(jac)))

    def forward(self, value, *inputs):
        """
        Convert the off-diagonal upper triangular elements of a correlation
        matrix (row-major order) to unconstrained real numbers.
        """
        # Rebuild the full symmetric matrix from its strictly upper triangle;
        # the diagonal of a correlation matrix is all ones.
        upper = pt.set_subtensor(
            pt.zeros((self.n, self.n))[self.triu_r_idxs, self.triu_c_idxs],
            value,
        )
        corr = upper + upper.T + pt.eye(self.n)

        chol = pt.linalg.cholesky(corr)

        # ``backward`` builds a matrix with a unit diagonal and then normalizes
        # each row, so dividing every row by its diagonal entry recovers that
        # pre-normalization matrix. ``abs`` is only a defensive guard on the
        # sign of the diagonal.
        inv_diag = 1 / pt.abs(pt.diag(chol))
        unconstrained = chol * inv_diag[:, None]

        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

    def backward(self, value, *inputs, foo=False):
        """
        Convert unconstrained real numbers to the off-diagonal upper triangular
        elements of a correlation matrix (row-major order).

        ``foo`` is unused; it is kept only so existing callers keep working.
        """
        # The diagonals of this matrix are 1, but these ones are just used for
        # computing a denominator. The diagonals of the cholesky factor are not
        # returned, but they are not ones.
        chol_pre_norm = pt.set_subtensor(
            pt.eye(self.n).astype("floatX")[self.tril_r_idxs, self.tril_c_idxs],
            value,
        )

        # The row norm is spelled out by hand because the derivative of
        # pt.linalg.norm ended up complex, which caused errors.
        row_norm = pt.pow(pt.abs(pt.pow(chol_pre_norm, 2).sum(1)), 0.5)
        chol = chol_pre_norm / row_norm[:, None]

        # Undo the cholesky decomposition
        corr = pt.matmul(chol, chol.T)

        # We want the upper triangular indices here.
        return corr[self.triu_r_idxs, self.triu_c_idxs]
| 251 | + |
141 | 252 | class CholeskyCovPacked(Transform):
|
142 | 253 | """
|
143 | 254 | Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
|
|
0 commit comments