Skip to content

Commit e6ac3fb

Browse files
committed
fixes to gp classes
1 parent ba8cce6 commit e6ac3fb

File tree

1 file changed

+70
-64
lines changed

1 file changed

+70
-64
lines changed

pymc3/gp/gp.py

Lines changed: 70 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def marginal_likelihood(self, name, X, *args, **kwargs):
5555
def conditional(self, name, n_points, Xnew, *args, **kwargs):
5656
raise NotImplementedError
5757

58+
def predict(self, Xnew, point=None, given=None, diag=False):
59+
raise NotImplementedError
5860

5961
@conditioned_vars(["X", "f"])
6062
class Latent(Base):
@@ -85,33 +87,23 @@ def _get_cond_vals(self, other=None):
8587
else:
8688
return other.X, other.f, other.cov_func, other.mean_func
8789

88-
def _build_conditional(self, Xnew, X, f, cov_total, mean_total, diag=False):
90+
def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
8991
Kxx = cov_total(X)
9092
Kxs = self.cov_func(X, Xnew)
9193
L = cholesky(stabilize(Kxx))
9294
A = solve_lower(L, Kxs)
9395
v = solve_lower(L, f - mean_total(X))
9496
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
95-
if diag:
96-
Kss = self.cov_func(Xnew, diag=True)
97-
cov = Kss - tt.sum(tt.square(A), 0)
98-
else:
99-
Kss = self.cov_func(Xnew)
100-
cov = Kss - tt.dot(tt.transpose(A), A)
97+
Kss = self.cov_func(Xnew)
98+
cov = Kss - tt.dot(tt.transpose(A), A)
10199
return mu, cov
102100

103-
def conditional(self, name, n_points, Xnew, gp=None):
104-
X, f, cov_total, mean_total = self._get_cond_vals(gp)
101+
def conditional(self, name, n_points, Xnew, given=None):
102+
X, f, cov_total, mean_total = self._get_cond_vals(given)
105103
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total)
106104
chol = cholesky(stabilize(cov))
107105
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
108106

109-
def predict(self, Xnew, point=None, gp=None, diag=False):
110-
X, f, cov_total, mean_total = self._get_cond_vals(gp)
111-
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total, diag)
112-
mu, cov = draw_values([mu, cov], point=point)
113-
return mu, cov
114-
115107

116108
@conditioned_vars(["X", "f", "nu"])
117109
class TP(Latent):
@@ -127,45 +119,42 @@ def __init__(self, mean_func=None, cov_func=None, nu=None):
127119
def __add__(self, other):
128120
raise ValueError("Student T processes aren't additive")
129121

130-
def _build_prior(self, name, n_points, X, nu):
122+
def _build_prior(self, name, n_points, X, reparameterize=True):
131123
mu = self.mean_func(X)
132124
chol = cholesky(stabilize(self.cov_func(X)))
133-
134-
chi2 = pm.ChiSquared("chi2_", nu)
135-
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
136-
f = pm.Deterministic(name, (tt.sqrt(nu) / chi2) * (mu + tt.dot(chol, v)))
125+
if reparameterize:
126+
chi2 = pm.ChiSquared("chi2_", self.nu)
127+
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
128+
f = pm.Deterministic(name, (tt.sqrt(self.nu) / chi2) * (mu + tt.dot(chol, v)))
129+
else:
130+
f = pm.MvStudentT(name, nu=self.nu, mu=mu, chol=chol, shape=n_points)
137131
return f
138132

139-
def prior(self, name, n_points, X, nu):
140-
f = self._build_prior(name, n_points, X, nu)
133+
def prior(self, name, n_points, X, reparameterize=True):
134+
f = self._build_prior(name, n_points, X, reparameterize)
141135
self.X = X
142-
self.nu = nu
143136
self.f = f
144137
return f
145138

146-
def _build_conditional(self, Xnew, X, f, nu):
147-
Kxx = self.cov_total(X)
139+
def _build_conditional(self, Xnew, X, f):
140+
Kxx = self.cov_func(X)
148141
Kxs = self.cov_func(X, Xnew)
149142
Kss = self.cov_func(Xnew)
150143
L = cholesky(stabilize(Kxx))
151144
A = solve_lower(L, Kxs)
152145
cov = Kss - tt.dot(tt.transpose(A), A)
153-
154-
v = solve_lower(L, f - self.mean_total(X))
146+
v = solve_lower(L, f - self.mean_func(X))
155147
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
156-
157148
beta = tt.dot(v, v)
158-
nu2 = nu + X.shape[0]
159-
160-
covT = (nu + beta - 2)/(nu2 - 2) * cov
149+
nu2 = self.nu + X.shape[0]
150+
covT = (self.nu + beta - 2)/(nu2 - 2) * cov
151+
return nu2, mu, covT
152+
153+
def conditional(self, name, n_points, Xnew):
154+
X = self.X
155+
f = self.f
156+
nu2, mu, covT = self._build_conditional(Xnew, X, f)
161157
chol = cholesky(stabilize(covT))
162-
return nu2, mu, chol
163-
164-
def conditional(self, name, n_points, Xnew, X=None, f=None, nu=None):
165-
if X is None: X = self.X
166-
if f is None: f = self.f
167-
if nu is None: nu = self.nu
168-
nu2, mu, chol = self._build_conditional(Xnew, X, f, nu)
169158
return pm.MvStudentT(name, nu=nu2, mu=mu, chol=chol, shape=n_points)
170159

171160

@@ -226,16 +215,16 @@ def _get_cond_vals(self, other=None):
226215
else:
227216
return other.X, other.y, other.noise, other.cov_func, other.mean_func
228217

229-
def conditional(self, name, n_points, Xnew, gp=None, pred_noise=False):
218+
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
230219
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
231-
X, y, noise, cov_total, mean_total = self._get_cond_vals(gp)
220+
X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
232221
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
233222
pred_noise, diag=False)
234223
chol = cholesky(cov)
235224
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
236225

237-
def predict(self, Xnew, point=None, gp=None, pred_noise=False, diag=False):
238-
X, y, noise, cov_total, mean_total = self._get_cond_vals(gp)
226+
def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
227+
X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
239228
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
240229
pred_noise, diag)
241230
mu, cov = draw_values([mu, cov], point=point)
@@ -291,7 +280,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
291280
quadratic = 0.5 * (tt.dot(r, r_l) - tt.dot(c, c))
292281
return -1.0 * (constant + logdet + quadratic + trace)
293282

294-
def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True):
283+
def marginal_likelihood(self, name, X, Xu, y, sigma, n_points=None, is_observed=True):
295284
self.X = X
296285
self.Xu = Xu
297286
self.y = y
@@ -300,49 +289,66 @@ def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True)
300289
if is_observed: # same thing ith n_points here?? check
301290
return pm.DensityDist(name, logp, observed=y)
302291
else:
303-
return pm.DensityDist(name, logp, size=n_points) # need size? if not, dont need size arg
292+
if n_points is None:
293+
raise ValueError("When `y` is not observed, `n_points` arg is required")
294+
return pm.DensityDist(name, logp, size=n_points) # not, dont need size arg
304295

305-
def _build_conditional(self, Xnew, Xu, X, y, sigma, pred_noise):
296+
def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
297+
pred_noise, diag=False):
306298
sigma2 = tt.square(sigma)
307-
Kuu = self.cov_func(Xu)
308-
Kuf = self.cov_func(Xu, X)
299+
Kuu = cov_total(Xu)
300+
Kuf = cov_total(Xu, X)
309301
Luu = cholesky(stabilize(Kuu))
310302
A = solve_lower(Luu, Kuf)
311303
Qffd = tt.sum(A * A, 0)
312304
if self.approx not in self._available_approx:
313305
raise NotImplementedError(self.approx)
314306
elif self.approx == "FITC":
315-
Kffd = self.cov_func(X, diag=True)
307+
Kffd = cov_total(X, diag=True)
316308
Lamd = tt.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
317309
else: # VFE or DTC
318310
Lamd = tt.ones_like(Qffd) * sigma2
319311
A_l = A / Lamd
320312
L_B = cholesky(tt.eye(Xu.shape[0]) + tt.dot(A_l, tt.transpose(A)))
321-
r = y - self.mean_func(X)
313+
r = y - mean_total(X)
322314
r_l = r / Lamd
323315
c = solve_lower(L_B, tt.dot(A, r_l))
324316
Kus = self.cov_func(Xu, Xnew)
325317
As = solve_lower(Luu, Kus)
326-
mean = (self.mean_func(Xnew) +
327-
tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c)))
318+
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c))
328319
C = solve_lower(L_B, As)
329-
if pred_noise:
330-
cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
331-
tt.dot(tt.transpose(C), C) + sigma2*tt.eye(Xnew.shape[0]))
320+
if diag:
321+
Kss = self.cov_func(Xnew, diag=True)
322+
var = Kss - tt.sum(tt.sqaure(As), 0) + tt.sum(tt.square(C), 0)
323+
if pred_noise:
324+
var += sigma2
325+
return mu, var
332326
else:
333327
cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
334328
tt.dot(tt.transpose(C), C))
335-
return mean, stabilize(cov)
336-
337-
def conditional(self, name, n_points, Xnew, Xu=None, X=None, y=None,
338-
sigma=None, pred_noise=False):
339-
if Xu is None: Xu = self.Xu
340-
if X is None: X = self.X
341-
if y is None: y = self.y
342-
if sigma is None: sigma = self.sigma
343-
mu, chol = self._build_conditional(Xnew, Xu, X, y, sigma, pred_noise)
344-
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
329+
if pred_noise:
330+
cov += sigma2 * tt.identity_like(cov)
331+
return mu, stabilize(cov)
332+
333+
def _get_cond_vals(self, other=None):
334+
if other is None:
335+
return self.X, self.Xu, self.y, self.sigma, self.cov_func, self.mean_func,
336+
else:
337+
return other.X, self.Xu, other.y, other.sigma, other.cov_func, other.mean_func
345338

339+
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
340+
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
341+
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
342+
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
343+
pred_noise, diag=False)
344+
chol = cholesky(cov)
345+
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
346346

347+
def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
348+
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
349+
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
350+
pred_noise, diag)
351+
mu, cov = draw_values([mu, cov], point=point)
352+
return mu, cov
347353

348354

0 commit comments

Comments
 (0)