Skip to content

Commit 5e329a5

Browse files
author
Junpeng Lao
committed
Use more stable LKJCholeskyCov
1 parent 74c5b2e commit 5e329a5

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

pymc3/examples/LKJ_correlation.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,54 @@
11
import theano.tensor as tt
22
import numpy as np
33
from numpy.random import multivariate_normal
4-
54
import pymc3 as pm
65

76
# Generate some multivariate normal data:
87
n_obs = 1000
98

109
# Mean values:
11-
mu = np.linspace(0, 2, num=4)
12-
n_var = len(mu)
10+
mu_r = np.linspace(0, 2, num=4)
11+
n_var = len(mu_r)
1312

1413
# Standard deviations:
1514
stds = np.ones(4) / 2.0
1615

1716
# Correlation matrix of 4 variables:
18-
corr = np.array([[1., 0.75, 0., 0.15],
19-
[0.75, 1., -0.06, 0.19],
20-
[0., -0.06, 1., -0.04],
21-
[0.15, 0.19, -0.04, 1.]])
22-
cov_matrix = np.diag(stds).dot(corr.dot(np.diag(stds)))
23-
24-
dataset = multivariate_normal(mu, cov_matrix, size=n_obs)
25-
17+
corr_r = np.array([[1., 0.75, 0., 0.15],
18+
[0.75, 1., -0.06, 0.19],
19+
[0., -0.06, 1., -0.04],
20+
[0.15, 0.19, -0.04, 1.]])
21+
cov_matrix = np.diag(stds).dot(corr_r.dot(np.diag(stds)))
2622

27-
# In order to convert the upper triangular correlation values to a complete
28-
# correlation matrix, we need to construct an index matrix:
29-
n_elem = int(n_var * (n_var - 1) / 2)
30-
tri_index = np.zeros([n_var, n_var], dtype=int)
31-
tri_index[np.triu_indices(n_var, k=1)] = np.arange(n_elem)
32-
tri_index[np.triu_indices(n_var, k=1)[::-1]] = np.arange(n_elem)
23+
dataset = multivariate_normal(mu_r, cov_matrix, size=n_obs)
3324

3425
with pm.Model() as model:
3526

3627
mu = pm.Normal('mu', mu=0, sd=1, shape=n_var)
3728

38-
# We can specify separate priors for sigma and the correlation matrix:
39-
sigma = pm.Uniform('sigma', shape=n_var)
40-
corr_triangle = pm.LKJCorr('corr', n=1, p=n_var)
41-
corr_matrix = corr_triangle[tri_index]
42-
corr_matrix = tt.fill_diagonal(corr_matrix, 1)
29+
# Note that we access the distribution for the standard
30+
# deviations, and do not create a new random variable.
31+
sd_dist = pm.HalfCauchy.dist(beta=2.5)
32+
packed_chol = pm.LKJCholeskyCov('chol_cov', n=n_var, eta=1, sd_dist=sd_dist)
33+
# compute the covariance matrix
34+
chol = pm.expand_packed_triangular(n_var, packed_chol, lower=True)
35+
cov = tt.dot(chol, chol.T)
4336

44-
cov_matrix = tt.diag(sigma).dot(corr_matrix.dot(tt.diag(sigma)))
37+
# Extract the standard deviations etc
38+
sd = pm.Deterministic('sd', tt.sqrt(tt.diag(cov)))
39+
corr = tt.diag(sd**-1).dot(cov.dot(tt.diag(sd**-1)))
40+
r = pm.Deterministic('r', corr[np.triu_indices(n_var, k=1)])
4541

46-
like = pm.MvNormal('likelihood', mu=mu, cov=cov_matrix, observed=dataset)
42+
like = pm.MvNormal('likelihood', mu=mu, chol=chol, observed=dataset)
4743

4844

4945
def run(n=1000):
5046
if n == "short":
5147
n = 50
5248
with model:
53-
start = pm.find_MAP()
54-
step = pm.NUTS(scaling=start)
55-
trace = pm.sample(n, step=step, start=start)
56-
return trace
49+
trace = pm.sample(n)
50+
pm.traceplot(trace, varnames=['mu', 'r'],
51+
lines={'mu': mu_r, 'r': corr_r[np.triu_indices(n_var, k=1)]})
5752

5853
if __name__ == '__main__':
5954
run()

0 commit comments

Comments
 (0)