
Commit 9d008c5

Merge branch 'localchanges' into gp-module
2 parents: 20e06b1 + 1a9e68f

2 files changed: +190 -0 lines changed


pymc3/gp/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from . import mean
 from . import util
 from .gp import Latent, Marginal, MarginalSparse, TP
+from .grid import Grid2DLatent

pymc3/gp/grid.py

Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
import numpy as np
import theano.tensor as tt
import pymc3 as pm
from pymc3.gp.gp import Base
from pymc3.gp.cov import Covariance
from pymc3.gp.mean import Constant
from pymc3.gp.util import (conditioned_vars, infer_shape, stabilize,
                           cholesky, solve, solve_lower, solve_upper)


__all__ = ["Grid2DLatent"]

def vec(A):
    # column-stacking vectorization of a matrix
    return tt.flatten(tt.transpose(A))[:, None]


def devec(x, r, c):
    # inverse of vec for an r x c matrix
    return x.reshape((c, r)).T


def grid_to_full(X1, X2, n1, n2):
    # expand the two grid axes into the full set of grid points,
    # one row per (x1, x2) combination
    n1 = infer_shape(X1, n1)
    n2 = infer_shape(X2, n2)
    X1 = tt.as_tensor_variable(X1)
    X2 = tt.as_tensor_variable(X2)
    xx1 = tt.repeat(X1, n2, 0)
    xx2 = tt.tile(X2, (n1, 1))
    return tt.concatenate((xx1, xx2), 1)


def kronprod(A1, A2, v, veced=True):
    # multiply the Kronecker-structured matrix into v: returns A2 V A1,
    # where V = devec(v) when v comes in vectorized
    if veced:
        v = devec(v, A2.shape[0], A1.shape[0])
    return tt.dot(A2, tt.dot(v, A1))


def kronsolve(A1, A2, v, veced=True, chol=False):
    """ Solve the Kronecker-structured system by solving with A2 and A1
    separately.  If chol is True, assume A1, A2 are Cholesky factors. """
    if veced:
        v = devec(v, A2.shape[0], A1.shape[0])
    if chol:
        return vec(solve(A2, tt.transpose(solve_upper(tt.transpose(A1), tt.transpose(v)))))
    else:
        return vec(solve(A2, tt.transpose(solve(tt.transpose(A1), tt.transpose(v)))))


def make_gridcov_func(cov_func1, cov_func2):
    # covariance over full (non-grid) inputs: the product of the two
    # per-axis covariances, with input columns split by input_dim
    ndim1 = cov_func1.input_dim
    def gridcov_func(X, Xnew=None):
        if Xnew is not None:
            return (cov_func1(X[:, :ndim1], Xnew[:, :ndim1]) *
                    cov_func2(X[:, ndim1:], Xnew[:, ndim1:]))
        else:
            return cov_func1(X[:, :ndim1]) * cov_func2(X[:, ndim1:])
    return gridcov_func


class Grid(Base):
    def __init__(self, cov_funcs, mean_func):
        self.cov_func1, self.cov_func2 = cov_funcs
        self.mean_func = mean_func
        # dimension convention, e.g.: K1 is 5x5, K2 is 3x3, Y is 3x5

@conditioned_vars(["X", "n", "f"])
class Grid2DLatent(Grid):
    def _build_prior(self, name, X, n, reparameterize=True):
        mu = 0.0
        L1 = cholesky(stabilize(self.cov_func1(X[0])))
        L2 = cholesky(stabilize(self.cov_func2(X[1])))
        if reparameterize:
            # non-centered parameterization: scale iid normal draws by the
            # per-axis Cholesky factors via the Kronecker product
            v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n[0] * n[1])
            f = pm.Deterministic(name, devec(kronprod(L1, L2, v[:, None]), n[0], n[1]))
        else:
            raise NotImplementedError
        return f

    def prior(self, name, X, n=None, reparameterize=True):
        if n is None:
            n = (None, None)
        if len(X) != 2 or len(n) != 2:
            raise ValueError("only 2D grids are implemented")
        n = (infer_shape(X[0], n[0]), infer_shape(X[1], n[1]))
        f = self._build_prior(name, X, n, reparameterize)
        self.X = (tt.as_tensor_variable(X[0]), tt.as_tensor_variable(X[1]))
        self.n = n
        self.f = f
        return f

    def _get_given_vals(self, **given):
        if 'gp' in given:
            cov_total = (given['gp'].cov_func1, given['gp'].cov_func2)
            mean_total = given['gp'].mean_func
        else:
            cov_total = (self.cov_func1, self.cov_func2)
            mean_total = self.mean_func
        if all(val in given for val in ['X', 'n', 'f']):
            X, n, f = given['X'], given['n'], given['f']
            if not (isinstance(X, tuple) and isinstance(n, tuple)):
                raise ValueError("X and n must be tuples with one element per grid dimension")
        else:
            X, n, f = self.X, self.n, self.f
        return X, n, f, cov_total, mean_total

    def conditional(self, name, Xnew, n_points=None, given=None):
        if given is None:
            given = {}
        givens = self._get_given_vals(**given)
        mu, cov, n_points = self._build_conditional(Xnew, n_points, *givens)
        chol = cholesky(stabilize(cov))
        return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)

    def _build_conditional(self, Xnew, n_points, X, n, f, cov_total, mean_total):
        Kxx1 = self.cov_func1(X[0])
        Kxx2 = self.cov_func2(X[1])
        L1 = cholesky(stabilize(Kxx1))
        L2 = cholesky(stabilize(Kxx2))
        if isinstance(Xnew, tuple) and isinstance(n_points, tuple):
            # if X1new and X2new are given separately, Xnew comes in as a tuple
            Kxs = tt.slinalg.kron(self.cov_func1(self.X[0], Xnew[0]),
                                  self.cov_func2(self.X[1], Xnew[1]))
            Kss = tt.slinalg.kron(self.cov_func1(Xnew[0]),
                                  self.cov_func2(Xnew[1]))
            Xnew = grid_to_full(Xnew[0], Xnew[1], n_points[0], n_points[1])
            n_points = np.prod(n_points)
        else:
            # predict given full (non-grid) input points
            gridcov_func = make_gridcov_func(self.cov_func1, self.cov_func2)
            X = grid_to_full(self.X[0], self.X[1], self.n[0], self.n[1])
            Kxs = gridcov_func(X, Xnew)
            Kss = gridcov_func(Xnew)
            n_points = infer_shape(Xnew, n_points)
        A = kronsolve(L1, L2, Kxs, veced=False, chol=True)
        v = kronsolve(L1, L2, f, veced=False, chol=True)  # f - mean_total(X)
        #mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
        mu = tt.dot(tt.transpose(A), v)
        cov = Kss - tt.dot(tt.transpose(A), A)
        return mu, cov, n_points


@conditioned_vars(["X", "n", "f"])
class Grid2DMarginal(Grid):

    def _build_marginal_likelihood_logp(self, X, Y, sigma, n):
        mu = 0.0
        K1 = stabilize(self.cov_func1(X[0]))
        K2 = stabilize(self.cov_func2(X[1]))
        if sigma is None:  # use cholesky
            L1 = cholesky(K1)
            L2 = cholesky(K2)
            const = n[0] * n[1] * tt.log(2.0 * np.pi)
            logdet = (n[1] * 2.0 * tt.sum(tt.log(tt.diag(L1))) +
                      n[0] * 2.0 * tt.sum(tt.log(tt.diag(L2))))
            tmp = kronsolve(L1, L2, Y, veced=False, chol=True)
            quad = tt.sum(tt.square(tmp))
            return -0.5 * (logdet + quad + const)
        else:  # use eigendecomposition to absorb the noise term
            S1, Q1 = tt.nlinalg.eigh(K1)
            S2, Q2 = tt.nlinalg.eigh(K2)
            # symbolic equivalent of np.kron(S2, S1) for the eigenvalue vectors
            W = tt.flatten(tt.outer(S2, S1)) + tt.square(sigma)
            Qinvy = kronsolve(Q1, Q2, Y, veced=False, chol=False)
            const = n[0] * n[1] * tt.log(2.0 * np.pi)
            logdet = tt.sum(tt.log(W))
            quad = tt.sum(tt.square(tt.sqrt(1.0 / W) * Qinvy))
            return -0.5 * (logdet + quad + const)


    def marginal_likelihood(self, name, X, Y, sigma=None, n=None):
        if n is None:
            n = (None, None)
        if len(X) != 2 or len(n) != 2:
            raise ValueError("only 2D grids are implemented")
        n = (infer_shape(X[0], n[0]), infer_shape(X[1], n[1]))
        logp = lambda Y: self._build_marginal_likelihood_logp(X, Y, sigma, n)
        self.X = (tt.as_tensor_variable(X[0]), tt.as_tensor_variable(X[1]))
        self.n = n
        return pm.DensityDist(name, logp, observed=Y)
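
For context, a minimal sketch (not part of this commit) of how the new Grid2DLatent class might be used, assuming the Grid constructor signature (cov_funcs, mean_func) and the tuple-valued X and n arguments shown in grid.py; the kernels, lengthscales, and grid sizes below are illustrative only:

import numpy as np
import pymc3 as pm

# two 1-D grid axes; the full grid has n1 * n2 = 15 points
x1 = np.linspace(0, 1, 5)[:, None]   # first axis, n1 = 5
x2 = np.linspace(0, 1, 3)[:, None]   # second axis, n2 = 3

with pm.Model() as model:
    cov1 = pm.gp.cov.ExpQuad(1, ls=0.2)
    cov2 = pm.gp.cov.ExpQuad(1, ls=0.5)
    gp = pm.gp.Grid2DLatent(cov_funcs=(cov1, cov2), mean_func=pm.gp.mean.Zero())
    # latent GP prior over the 5 x 3 grid
    f = gp.prior("f", X=(x1, x2), n=(5, 3))

Because kronprod and kronsolve work with the two per-axis factors directly, only the small 5x5 and 3x3 covariance matrices are ever factored, never the full 15x15 Kronecker product.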
