Skip to content

Commit 8d279fc

Browse files
Basic implementation
1 parent ce5f2a2 commit 8d279fc

File tree

2 files changed

+67
-14
lines changed

2 files changed

+67
-14
lines changed

pymc/distributions/multivariate.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,7 @@ def helper_deterministics(cls, n, packed_chol):
15101510

15111511
class LKJCorrRV(RandomVariable):
15121512
name = "lkjcorr"
1513-
signature = "(),()->(n)"
1513+
signature = "(),()->(n,n)"
15141514
dtype = "floatX"
15151515
_print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")
15161516

@@ -1527,8 +1527,8 @@ def make_node(self, rng, size, n, eta):
15271527

15281528
def _supp_shape_from_params(self, dist_params, **kwargs):
15291529
n = dist_params[0].squeeze()
1530-
dist_shape = ((n * (n - 1)) // 2,)
1531-
return dist_shape
1530+
# dist_shape = ((n * (n - 1)) // 2,)
1531+
return (n, n)
15321532

15331533
@classmethod
15341534
def rng_fn(cls, rng, n, eta, size):
@@ -1609,23 +1609,26 @@ def logp(value, n, eta):
16091609
-------
16101610
TensorVariable
16111611
"""
1612-
if value.ndim > 1:
1613-
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
1614-
1615-
# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
1616-
# n (or else find a different expression)
1612+
# if value.ndim > 1:
1613+
# raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
1614+
#
16171615
try:
16181616
n = int(get_underlying_scalar_constant_value(n))
16191617
except NotScalarConstantError:
16201618
raise NotImplementedError("logp only implemented for constant `n`")
16211619

1622-
shape = n * (n - 1) // 2
1623-
tri_index = np.zeros((n, n), dtype="int32")
1624-
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1625-
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1620+
# shape = n * (n - 1) // 2
1621+
# tri_index = np.zeros((n, n), dtype="int32")
1622+
# tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1623+
# tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1624+
1625+
# value = pt.take(value, tri_index)
1626+
# value = pt.fill_diagonal(value, 1)
16261627

1627-
value = pt.take(value, tri_index)
1628-
value = pt.fill_diagonal(value, 1)
1628+
# print(n, type(n))
1629+
# print(value.type.shape)
1630+
# value = value @ value.T
1631+
# print(value.type.shape)
16291632

16301633
# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
16311634
try:

pymc/distributions/transforms.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,56 @@ def log_jac_det(self, value, *inputs):
164164
return pt.sum(value[..., self.diag_idxs], axis=-1)
165165

166166

167+
class CholeskyCorr(Transform):
168+
"""Get a Cholesky Corr from a packed vector."""
169+
170+
name = "cholesky-corr-packed"
171+
172+
def __init__(self, n):
173+
"""Create a CholeskyCorrPack object.
174+
175+
Parameters
176+
----------
177+
n: int
178+
Number of diagonal entries in the LKJCholeskyCov distribution
179+
"""
180+
self.n = n
181+
182+
def _compute_L_and_logdet(self, value, *inputs):
183+
n = self.n
184+
counter = 0
185+
L = pt.eye(n)
186+
log_det = 0
187+
188+
for i in range(1, n):
189+
y_star = value[counter : counter + i]
190+
dsy = y_star.dot(y_star)
191+
alpha_r = 1 / (dsy + 1)
192+
gamma = pt.sqrt(dsy + 2) * alpha_r
193+
194+
x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r))
195+
L = L[i, : i + 1].set(x)
196+
log_det += pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy)
197+
198+
counter += i
199+
200+
# Return whole matrix? Or just lower triangle?
201+
return L, log_det
202+
203+
def backward(self, value, *inputs):
204+
L, _ = self._compute_L_and_logdet(value, *inputs)
205+
return L
206+
207+
def forward(self, value, *inputs):
208+
# TODO: This is a placeholder
209+
n = self.n
210+
return pt.as_tensor_variable(np.random.normal(size=(n,)))
211+
212+
def log_jac_det(self, value, *inputs):
213+
_, log_det = self._compute_L_and_logdet(value, *inputs)
214+
return log_det
215+
216+
167217
Chain = ChainedTransform
168218

169219
simplex = SimplexTransform()

0 commit comments

Comments
 (0)