Skip to content

Commit 6f5ceb6

Browse files
committed
Merge branch 'localchanges' into gp-module
2 parents 5197d5c + 74d14d8 commit 6f5ceb6

File tree

5 files changed

+1159
-698
lines changed

5 files changed

+1159
-698
lines changed

docs/source/notebooks/GP-Latent.ipynb

Lines changed: 114 additions & 100 deletions
Large diffs are not rendered by default.

docs/source/notebooks/GP-Marginal.ipynb

Lines changed: 155 additions & 69 deletions
Large diffs are not rendered by default.

docs/source/notebooks/GP-MaunaLoa.ipynb

Lines changed: 764 additions & 408 deletions
Large diffs are not rendered by default.

pymc3/gp/gp.py

Lines changed: 123 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import pymc3 as pm
88
from pymc3.gp.cov import Covariance
9+
from pymc3.gp.mean import Constant
910
from pymc3.gp.util import conditioned_vars
11+
from pymc3.distributions import draw_values
1012

1113
__all__ = ['Latent', 'Marginal', 'TP', 'MarginalSparse']
1214

@@ -20,7 +22,7 @@ def stabilize(K):
2022
return K + 1e-6 * tt.identity_like(K)
2123

2224

23-
class GPBase(object):
25+
class Base(object):
2426
"""
2527
Base class
2628
"""
@@ -36,44 +38,13 @@ def __init__(self, mean_func=None, cov_func=None):
3638
self.mean_func = mean_func
3739
self.cov_func = cov_func
3840

39-
@property
40-
def cov_total(self):
41-
total = getattr(self, "_cov_total", None)
42-
if total is None:
43-
return self.cov_func
44-
else:
45-
return total
46-
47-
@cov_total.setter
48-
def cov_total(self, new_cov_total):
49-
self._cov_total = new_cov_total
50-
51-
@property
52-
def mean_total(self):
53-
total = getattr(self, "_mean_total", None)
54-
if total is None:
55-
return self.mean_func
56-
else:
57-
return total
58-
59-
@mean_total.setter
60-
def mean_total(self, new_mean_total):
61-
self._mean_total = new_mean_total
62-
6341
def __add__(self, other):
6442
same_attrs = set(self.__dict__.keys()) == set(other.__dict__.keys())
6543
if not isinstance(self, type(other)) and not same_attrs:
6644
raise ValueError("cant add different GP types")
67-
68-
# set cov_func and mean_func of new GP
69-
cov_total = self.cov_func + other.cov_func
7045
mean_total = self.mean_func + other.mean_func
71-
72-
# update self and other mean and cov totals
73-
self.cov_total, self.mean_total = (cov_total, mean_total)
74-
other.cov_total, other.mean_total = (cov_total, mean_total)
75-
new_gp = self.__class__(mean_total, cov_total)
76-
return new_gp
46+
cov_total = self.cov_func + other.cov_func
47+
return self.__class__(mean_total, cov_total)
7748

7849
def prior(self, name, X, *args, **kwargs):
7950
raise NotImplementedError
@@ -84,9 +55,11 @@ def marginal_likelihood(self, name, X, *args, **kwargs):
8455
def conditional(self, name, n_points, Xnew, *args, **kwargs):
8556
raise NotImplementedError
8657

58+
def predict(self, Xnew, point=None, given=None, diag=False):
59+
raise NotImplementedError
8760

8861
@conditioned_vars(["X", "f"])
89-
class Latent(GPBase):
62+
class Latent(Base):
9063
""" Where the GP f isnt integrated out, and is sampled explicitly
9164
"""
9265
def __init__(self, mean_func=None, cov_func=None):
@@ -100,30 +73,35 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
10073
f = pm.Deterministic(name, mu + tt.dot(chol, v))
10174
else:
10275
f = pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
103-
self.X = X
104-
self.f = f
10576
return f
10677

10778
def prior(self, name, n_points, X, reparameterize=True):
10879
f = self._build_prior(name, n_points, X, reparameterize)
80+
self.X = X
81+
self.f = f
10982
return f
11083

111-
def _build_conditional(self, Xnew, X, f):
112-
Kxx = self.cov_total(X)
84+
def _get_cond_vals(self, other=None):
85+
if other is None:
86+
return self.X, self.f, self.cov_func, self.mean_func,
87+
else:
88+
return other.X, other.f, other.cov_func, other.mean_func
89+
90+
def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
91+
Kxx = cov_total(X)
11392
Kxs = self.cov_func(X, Xnew)
114-
Kss = self.cov_func(Xnew)
11593
L = cholesky(stabilize(Kxx))
11694
A = solve_lower(L, Kxs)
117-
cov = Kss - tt.dot(tt.transpose(A), A)
118-
chol = cholesky(stabilize(cov))
119-
v = solve_lower(L, f - self.mean_total(X))
95+
v = solve_lower(L, f - mean_total(X))
12096
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
121-
return mu, chol
97+
Kss = self.cov_func(Xnew)
98+
cov = Kss - tt.dot(tt.transpose(A), A)
99+
return mu, cov
122100

123-
def conditional(self, name, n_points, Xnew, X=None, f=None):
124-
if X is None: X = self.X
125-
if f is None: f = self.f
126-
mu, chol = self._build_conditional(Xnew, X, f)
101+
def conditional(self, name, n_points, Xnew, given=None):
102+
X, f, cov_total, mean_total = self._get_cond_vals(given)
103+
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total)
104+
chol = cholesky(stabilize(cov))
127105
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
128106

129107

@@ -141,58 +119,54 @@ def __init__(self, mean_func=None, cov_func=None, nu=None):
141119
def __add__(self, other):
142120
raise ValueError("Student T processes aren't additive")
143121

144-
def _build_prior(self, name, n_points, X, nu):
122+
def _build_prior(self, name, n_points, X, reparameterize=True):
145123
mu = self.mean_func(X)
146124
chol = cholesky(stabilize(self.cov_func(X)))
125+
if reparameterize:
126+
chi2 = pm.ChiSquared("chi2_", self.nu)
127+
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
128+
f = pm.Deterministic(name, (tt.sqrt(self.nu) / chi2) * (mu + tt.dot(chol, v)))
129+
else:
130+
f = pm.MvStudentT(name, nu=self.nu, mu=mu, chol=chol, shape=n_points)
131+
return f
147132

148-
chi2 = pm.ChiSquared("chi2_", nu)
149-
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
150-
f = pm.Deterministic(name, (tt.sqrt(nu) / chi2) * (mu + tt.dot(chol, v)))
151-
133+
def prior(self, name, n_points, X, reparameterize=True):
134+
f = self._build_prior(name, n_points, X, reparameterize)
152135
self.X = X
153136
self.f = f
154-
self.nu = nu
155137
return f
156138

157-
def prior(self, name, n_points, X, nu):
158-
f = self._build_prior(name, n_points, X, nu)
159-
return f
160-
161-
def _build_conditional(self, Xnew, X, f, nu):
162-
Kxx = self.cov_total(X)
139+
def _build_conditional(self, Xnew, X, f):
140+
Kxx = self.cov_func(X)
163141
Kxs = self.cov_func(X, Xnew)
164142
Kss = self.cov_func(Xnew)
165143
L = cholesky(stabilize(Kxx))
166144
A = solve_lower(L, Kxs)
167145
cov = Kss - tt.dot(tt.transpose(A), A)
168-
169-
v = solve_lower(L, f - self.mean_total(X))
146+
v = solve_lower(L, f - self.mean_func(X))
170147
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
171-
172148
beta = tt.dot(v, v)
173-
nu2 = nu + X.shape[0]
174-
175-
covT = (nu + beta - 2)/(nu2 - 2) * cov
149+
nu2 = self.nu + X.shape[0]
150+
covT = (self.nu + beta - 2)/(nu2 - 2) * cov
151+
return nu2, mu, covT
152+
153+
def conditional(self, name, n_points, Xnew):
154+
X = self.X
155+
f = self.f
156+
nu2, mu, covT = self._build_conditional(Xnew, X, f)
176157
chol = cholesky(stabilize(covT))
177-
return nu2, mu, chol
178-
179-
def conditional(self, name, n_points, Xnew, X=None, f=None, nu=None):
180-
if X is None: X = self.X
181-
if f is None: f = self.f
182-
if nu is None: nu = self.nu
183-
nu2, mu, chol = self._build_conditional(Xnew, X, f, nu)
184158
return pm.MvStudentT(name, nu=nu2, mu=mu, chol=chol, shape=n_points)
185159

186160

187161
@conditioned_vars(["X", "y", "noise"])
188-
class Marginal(GPBase):
162+
class Marginal(Base):
189163

190164
def __init__(self, mean_func=None, cov_func=None):
191165
super(Marginal, self).__init__(mean_func, cov_func)
192166

193167
def _build_marginal_likelihood(self, X, noise):
194168
mu = self.mean_func(X)
195-
Kxx = self.cov_total(X)
169+
Kxx = self.cov_func(X)
196170
Knx = noise(X)
197171
cov = Kxx + Knx
198172
chol = cholesky(stabilize(cov))
@@ -212,38 +186,53 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
212186
raise ValueError("When `y` is not observed, `n_points` arg is required")
213187
return pm.MvNormal(name, mu=mu, chol=chol, size=n_points)
214188

215-
def _build_conditional(self, Xnew, X, y, noise, pred_noise):
216-
Kxx = self.cov_total(X)
189+
def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
190+
pred_noise, diag=False):
191+
Kxx = cov_total(X)
217192
Kxs = self.cov_func(X, Xnew)
218-
Kss = self.cov_func(Xnew)
219193
Knx = noise(X)
220-
rxx = y - self.mean_total(X)
194+
rxx = y - mean_total(X)
221195
L = cholesky(stabilize(Kxx) + Knx)
222196
A = solve_lower(L, Kxs)
223197
v = solve_lower(L, rxx)
224198
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
225-
if pred_noise:
226-
cov = noise(Xnew) + Kss - tt.dot(tt.transpose(A), A)
199+
if diag:
200+
Kss = self.cov_func(Xnew, diag=True)
201+
var = Kss - tt.sum(tt.square(A), 0)
202+
if pred_noise:
203+
var += noise(Xnew, diag=True)
204+
return mu, var
227205
else:
228-
cov = stabilize(Kss) - tt.dot(tt.transpose(A), A)
229-
chol = cholesky(cov)
230-
return mu, chol
231-
232-
def conditional(self, name, n_points, Xnew, X=None, y=None,
233-
noise=None, pred_noise=False):
234-
if X is None: X = self.X
235-
if y is None: y = self.y
236-
if noise is None:
237-
noise = self.noise
206+
Kss = self.cov_func(Xnew)
207+
cov = Kss - tt.dot(tt.transpose(A), A)
208+
if pred_noise:
209+
cov += noise(Xnew)
210+
return mu, stabilize(cov)
211+
212+
def _get_cond_vals(self, other=None):
213+
if other is None:
214+
return self.X, self.y, self.noise, self.cov_func, self.mean_func,
238215
else:
239-
if not isinstance(noise, Covariance):
240-
noise = pm.gp.cov.WhiteNoise(noise)
241-
mu, chol = self._build_conditional(Xnew, X, y, noise, pred_noise)
216+
return other.X, other.y, other.noise, other.cov_func, other.mean_func
217+
218+
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
219+
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
220+
X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
221+
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
222+
pred_noise, diag=False)
223+
chol = cholesky(cov)
242224
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
243225

226+
def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
227+
X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
228+
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
229+
pred_noise, diag)
230+
mu, cov = draw_values([mu, cov], point=point)
231+
return mu, cov
232+
244233

245234
@conditioned_vars(["X", "Xu", "y", "sigma"])
246-
class MarginalSparse(GPBase):
235+
class MarginalSparse(Base):
247236
_available_approx = ["FITC", "VFE", "DTC"]
248237
""" FITC and VFE sparse approximations
249238
"""
@@ -291,7 +280,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
291280
quadratic = 0.5 * (tt.dot(r, r_l) - tt.dot(c, c))
292281
return -1.0 * (constant + logdet + quadratic + trace)
293282

294-
def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True):
283+
def marginal_likelihood(self, name, X, Xu, y, sigma, n_points=None, is_observed=True):
295284
self.X = X
296285
self.Xu = Xu
297286
self.y = y
@@ -300,49 +289,66 @@ def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True)
300289
if is_observed: # same thing ith n_points here?? check
301290
return pm.DensityDist(name, logp, observed=y)
302291
else:
303-
return pm.DensityDist(name, logp, size=n_points) # need size? if not, dont need size arg
292+
if n_points is None:
293+
raise ValueError("When `y` is not observed, `n_points` arg is required")
294+
return pm.DensityDist(name, logp, size=n_points) # not, dont need size arg
304295

305-
def _build_conditional(self, Xnew, Xu, X, y, sigma, pred_noise):
296+
def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
297+
pred_noise, diag=False):
306298
sigma2 = tt.square(sigma)
307-
Kuu = self.cov_func(Xu)
308-
Kuf = self.cov_func(Xu, X)
299+
Kuu = cov_total(Xu)
300+
Kuf = cov_total(Xu, X)
309301
Luu = cholesky(stabilize(Kuu))
310302
A = solve_lower(Luu, Kuf)
311303
Qffd = tt.sum(A * A, 0)
312304
if self.approx not in self._available_approx:
313305
raise NotImplementedError(self.approx)
314306
elif self.approx == "FITC":
315-
Kffd = self.cov_func(X, diag=True)
307+
Kffd = cov_total(X, diag=True)
316308
Lamd = tt.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
317309
else: # VFE or DTC
318310
Lamd = tt.ones_like(Qffd) * sigma2
319311
A_l = A / Lamd
320312
L_B = cholesky(tt.eye(Xu.shape[0]) + tt.dot(A_l, tt.transpose(A)))
321-
r = y - self.mean_func(X)
313+
r = y - mean_total(X)
322314
r_l = r / Lamd
323315
c = solve_lower(L_B, tt.dot(A, r_l))
324316
Kus = self.cov_func(Xu, Xnew)
325317
As = solve_lower(Luu, Kus)
326-
mean = (self.mean_func(Xnew) +
327-
tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c)))
318+
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c))
328319
C = solve_lower(L_B, As)
329-
if pred_noise:
330-
cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
331-
tt.dot(tt.transpose(C), C) + sigma2*tt.eye(Xnew.shape[0]))
320+
if diag:
321+
Kss = self.cov_func(Xnew, diag=True)
322+
var = Kss - tt.sum(tt.sqaure(As), 0) + tt.sum(tt.square(C), 0)
323+
if pred_noise:
324+
var += sigma2
325+
return mu, var
332326
else:
333327
cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
334328
tt.dot(tt.transpose(C), C))
335-
return mean, stabilize(cov)
336-
337-
def conditional(self, name, n_points, Xnew, Xu=None, X=None, y=None,
338-
sigma=None, pred_noise=False):
339-
if Xu is None: Xu = self.Xu
340-
if X is None: X = self.X
341-
if y is None: y = self.y
342-
if sigma is None: sigma = self.sigma
343-
mu, chol = self._build_conditional(Xnew, Xu, X, y, sigma, pred_noise)
344-
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
329+
if pred_noise:
330+
cov += sigma2 * tt.identity_like(cov)
331+
return mu, stabilize(cov)
345332

333+
def _get_cond_vals(self, other=None):
334+
if other is None:
335+
return self.X, self.Xu, self.y, self.sigma, self.cov_func, self.mean_func,
336+
else:
337+
return other.X, self.Xu, other.y, other.sigma, other.cov_func, other.mean_func
338+
339+
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
340+
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
341+
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
342+
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
343+
pred_noise, diag=False)
344+
chol = cholesky(cov)
345+
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
346346

347+
def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
348+
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
349+
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
350+
pred_noise, diag)
351+
mu, cov = draw_values([mu, cov], point=point)
352+
return mu, cov
347353

348354

0 commit comments

Comments
 (0)