Skip to content

Commit f845575

Browse files
author
Junpeng Lao
authored
Implement random method for LKJCorr (#2443)
* [WIP] implement random method for LKJCorr using the algorithm in LKJ 2009(vine method based on a C-vine) * pep8 fix * import random method * Revert "import random method" This reverts commit 42e620e. * restrict dimension n to n > 1, fix random sample for n = 2 * unify parameter naming between LKJCorr and LKJCholeskyCov * improved random method * Bug fix in shape * Add test * fix test * pep8 clean up, fixed test warning * improved random method * forked R code for random method generated Corr matrix is now positive definite.
1 parent 41d70b8 commit f845575

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

pymc3/distributions/multivariate.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ class LKJCholeskyCov(Continuous):
792792
Parameters
793793
----------
794794
n : int
795-
Dimension of the covariance matrix (n > 0).
795+
Dimension of the covariance matrix (n > 1).
796796
eta : float
797797
The shape parameter (eta > 0) of the LKJ distribution. eta = 1
798798
implies a uniform distribution of the correlation matrices;
@@ -962,7 +962,7 @@ class LKJCorr(Continuous):
962962
Parameters
963963
----------
964964
n : int
965-
Dimension of the covariance matrix (n > 0).
965+
Dimension of the covariance matrix (n > 1).
966966
eta : float
967967
The shape parameter (eta > 0) of the LKJ distribution. eta = 1
968968
implies a uniform distribution of the correlation matrices;
@@ -1006,20 +1006,48 @@ def __init__(self, eta=None, n=None, p=None, transform='interval', *args, **kwar
10061006
raise ValueError('Invalid parameter: please use eta as the shape parameter and '
10071007
'n as the dimension parameter.')
10081008

1009-
n_elem = int(n * (n - 1) / 2)
1010-
self.mean = np.zeros(n_elem, dtype=theano.config.floatX)
1009+
shape = n * (n - 1) // 2
1010+
self.mean = floatX(np.zeros(shape))
10111011

10121012
if transform == 'interval':
10131013
transform = transforms.interval(-1, 1)
10141014

1015-
super(LKJCorr, self).__init__(shape=n_elem, transform=transform,
1015+
super(LKJCorr, self).__init__(shape=shape, transform=transform,
10161016
*args, **kwargs)
10171017
warnings.warn('Parameters in LKJCorr have been rename: shape parameter n -> eta '
10181018
'dimension parameter p -> n. Please double check your initialization.',
10191019
DeprecationWarning)
10201020
self.tri_index = np.zeros([n, n], dtype='int32')
1021-
self.tri_index[np.triu_indices(n, k=1)] = np.arange(n_elem)
1022-
self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(n_elem)
1021+
self.tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
1022+
self.tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
1023+
1024+
def _random(self, n, eta, size=None):
1025+
size = size if isinstance(size, tuple) else (size,)
1026+
# original implementation in R see:
1027+
# https://github.com/rmcelreath/rethinking/blob/master/R/distributions.r
1028+
beta = eta - 1 + n/2
1029+
r12 = 2 * stats.beta.rvs(a=beta, b=beta, size=size) - 1
1030+
P = np.eye(n)[:, :, np.newaxis] * np.ones(size)
1031+
P[0, 1] = r12
1032+
P[1, 1] = np.sqrt(1 - r12**2)
1033+
if n > 2:
1034+
for m in range(1, n-1):
1035+
beta -= 0.5
1036+
y = stats.beta.rvs(a=(m+1) / 2., b=beta, size=size)
1037+
z = stats.norm.rvs(loc=0, scale=1, size=(m+1, ) + size)
1038+
z = z / np.sqrt(np.einsum('ij,ij->j', z, z))
1039+
P[0:m+1, m+1] = np.sqrt(y) * z
1040+
P[m+1, m+1] = np.sqrt(1 - y)
1041+
Pt = np.transpose(P, (2, 0 ,1))
1042+
C = np.einsum('...ji,...jk->...ik', Pt, Pt)
1043+
return C.transpose((1, 2, 0))[np.triu_indices(n, k=1)].T
1044+
1045+
def random(self, point=None, size=None):
1046+
n, eta = draw_values([self.n, self.eta], point=point)
1047+
size= 1 if size is None else size
1048+
samples = generate_samples(self._random, n, eta,
1049+
broadcast_shape=(size,))
1050+
return samples
10231051

10241052
def logp(self, x):
10251053
n = self.n

pymc3/tests/test_distributions_random.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,25 @@ def test_wishart(self):
621621
# st.wishart(V, df=n, size=size))
622622
pass
623623

624-
@pytest.mark.skip('LKJ random sampling not implemented yet.')
625624
def test_lkj(self):
626-
# TODO: generate random numbers.
627-
pass
625+
for n in [2, 10, 50]:
626+
#pylint: disable=cell-var-from-loop
627+
shape = n*(n-1)//2
628+
629+
def ref_rand(size, eta):
630+
beta = eta - 1 + n/2
631+
return (st.beta.rvs(size=(size, shape), a=beta, b=beta)-.5)*2
632+
633+
class TestedLKJCorr (pm.LKJCorr):
634+
635+
def __init__(self, **kwargs):
636+
kwargs.pop('shape', None)
637+
super(TestedLKJCorr, self).__init__(
638+
n=n,
639+
**kwargs
640+
)
641+
642+
pymc3_random(TestedLKJCorr,
643+
{'eta': Domain([1., 10., 100.])},
644+
size=10000//n,
645+
ref_rand=ref_rand)

0 commit comments

Comments
 (0)