Skip to content

Commit 24b464f

Browse files
committed
fixes to how given params work, cleanup
1 parent 32006a4 commit 24b464f

File tree

1 file changed

+69
-71
lines changed

1 file changed

+69
-71
lines changed

pymc3/gp/gp.py

Lines changed: 69 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,13 @@
77
import pymc3 as pm
88
from pymc3.gp.cov import Covariance
99
from pymc3.gp.mean import Constant
10-
from pymc3.gp.util import conditioned_vars
10+
from pymc3.gp.util import (conditioned_vars,
11+
infer_shape, stabilize, cholesky, solve, solve_lower, solve_upper)
1112
from pymc3.distributions import draw_values
1213

1314
__all__ = ['Latent', 'Marginal', 'TP', 'MarginalSparse']
1415

1516

16-
cholesky = pm.distributions.dist_math.Cholesky(nofail=True, lower=True)
17-
solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
18-
solve_upper = tt.slinalg.Solve(A_structure='upper_triangular')
19-
20-
def stabilize(K):
21-
""" adds small diagonal to a covariance matrix """
22-
return K + 1e-6 * tt.identity_like(K)
23-
24-
2517
class Base(object):
2618
"""
2719
Base class
@@ -58,14 +50,15 @@ def conditional(self, name, n_points, Xnew, *args, **kwargs):
5850
def predict(self, Xnew, point=None, given=None, diag=False):
5951
raise NotImplementedError
6052

53+
6154
@conditioned_vars(["X", "f"])
6255
class Latent(Base):
6356
""" Where the GP f isnt integrated out, and is sampled explicitly
6457
"""
6558
def __init__(self, mean_func=None, cov_func=None):
6659
super(Latent, self).__init__(mean_func, cov_func)
6760

68-
def _build_prior(self, name, n_points, X, reparameterize=True):
61+
def _build_prior(self, name, X, n_points, reparameterize=True):
6962
mu = self.mean_func(X)
7063
chol = cholesky(stabilize(self.cov_func(X)))
7164
if reparameterize:
@@ -75,17 +68,25 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
7568
f = pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
7669
return f
7770

78-
def prior(self, name, X, n_points=None, reparameterize=True):
    """Returns the GP prior random variable `f` evaluated over the inputs `X`.

    The number of points is inferred from `X` when `n_points` is not
    given.  The inputs and the resulting variable are cached on the
    object so that `conditional` can reuse them later.
    """
    size = infer_shape(X, n_points)
    f = self._build_prior(name, X, size, reparameterize)
    self.X, self.f = X, f
    return f
8377

84-
def _get_cond_vals(self, other=None):
85-
if other is None:
86-
return self.X, self.f, self.cov_func, self.mean_func,
78+
def _get_given_vals(self, **given):
79+
if 'gp' in given:
80+
cov_total = given['gp'].cov_func
81+
mean_total = given['gp'].mean_func
82+
else:
83+
cov_total = self.cov_func
84+
mean_total = self.mean_func
85+
if all(val in given for val in ['X', 'f']):
86+
X, f = given['X'], given['f']
8787
else:
88-
return other.X, other.f, other.cov_func, other.mean_func
88+
X, f = self.X, self.f
89+
return X, f, cov_total, mean_total
8990

9091
def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
9192
Kxx = cov_total(X)
@@ -98,10 +99,11 @@ def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
9899
cov = Kss - tt.dot(tt.transpose(A), A)
99100
return mu, cov
100101

101-
def conditional(self, name, Xnew, n_points=None, given=None):
    """Returns the conditional GP distribution evaluated at `Xnew`.

    `given` may carry 'gp', 'X' and 'f' entries to condition on values
    other than the cached ones (see `_get_given_vals`).  `n_points` is
    inferred from `Xnew` when not provided.
    """
    # Bug fix: `given` defaults to None and `**None` raises
    # "TypeError: argument after ** must be a mapping" whenever the
    # caller omits it; fall back to an empty dict instead.
    givens = self._get_given_vals(**(given if given is not None else {}))
    mu, cov = self._build_conditional(Xnew, *givens)
    chol = cholesky(stabilize(cov))
    n_points = infer_shape(Xnew, n_points)
    return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
106108

107109

@@ -150,11 +152,12 @@ def _build_conditional(self, Xnew, X, f):
150152
covT = (self.nu + beta - 2)/(nu2 - 2) * cov
151153
return nu2, mu, covT
152154

153-
def conditional(self, name, Xnew, n_points=None):
    """Returns the conditional Student-t process distribution at `Xnew`,
    conditioned on the cached prior inputs `self.X` and values `self.f`.
    """
    nu2, mu, covT = self._build_conditional(Xnew, self.X, self.f)
    chol = cholesky(stabilize(covT))
    shape = infer_shape(Xnew, n_points)
    return pm.MvStudentT(name, nu=nu2, mu=mu, chol=chol, shape=shape)
159162

160163

@@ -182,10 +185,24 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
182185
if is_observed:
183186
return pm.MvNormal(name, mu=mu, chol=chol, observed=y)
184187
else:
185-
if n_points is None:
186-
raise ValueError("When `y` is not observed, `n_points` arg is required")
188+
n_points = infer_shape(X, n_points)
187189
return pm.MvNormal(name, mu=mu, chol=chol, size=n_points)
188190

191+
def _get_given_vals(self, **given):
192+
if 'gp' in given:
193+
cov_total = given['gp'].cov_func
194+
mean_total = given['gp'].mean_func
195+
else:
196+
cov_total = self.cov_func
197+
mean_total = self.mean_func
198+
if all(val in given for val in ['X', 'y', 'noise']):
199+
X, y, noise = given['X'], given['y'], given['noise']
200+
if not isinstance(noise, Covariance):
201+
noise = pm.gp.cov.WhiteNoise(noise)
202+
else:
203+
X, y, noise = self.X, self.y, self.noise
204+
return X, y, noise, cov_total, mean_total
205+
189206
def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
190207
pred_noise, diag=False):
191208
Kxx = cov_total(X)
@@ -209,38 +226,32 @@ def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
209226
cov += noise(Xnew)
210227
return mu, stabilize(cov)
211228

212-
def _get_cond_vals(self, other=None):
213-
# LOOK AT THIS MORE, where to X, y, noise need to come from? depending on situation
214-
# provide this function with **kwargs and return those if given? X=X, y=y, etc
215-
# i think this would be good. could build the gp "from scratch". caching these
216-
# from prior or marglike is a convenience
217-
if other is None:
218-
return self.X, self.y, self.noise, self.cov_func, self.mean_func,
219-
else:
220-
return other.X, other.y, other.noise, other.cov_func, other.mean_func
221-
222-
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
223-
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
224-
X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
225-
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
226-
pred_noise, diag=False)
229+
def conditional(self, name, Xnew, pred_noise=False, n_points=None, **given):
    """Returns the conditional (predictive) MvNormal evaluated at `Xnew`.

    Extra keyword arguments ('gp', 'X', 'y', 'noise') override the
    cached conditioning values; see `_get_given_vals`.  `n_points` is
    inferred from `Xnew` when not provided.
    """
    givens = self._get_given_vals(**given)
    # Pass pred_noise by keyword: a positional argument after *givens
    # is Python-3-only syntax, and this file still targets Python 2
    # (see the `super(Class, self)` calls).
    mu, cov = self._build_conditional(Xnew, *givens,
                                      pred_noise=pred_noise, diag=False)
    chol = cholesky(cov)
    n_points = infer_shape(Xnew, n_points)
    return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
229235

230-
def predict(self, Xnew, point=None, diag=False, pred_noise=False, **given):
    """Evaluate the predictive mean and covariance at `Xnew` as concrete
    values, drawing any needed random-variable inputs from `point`.
    Extra keyword arguments are forwarded to `_get_given_vals`.
    """
    mu_sym, cov_sym = self.predictt(Xnew, diag, pred_noise, **given)
    mu_val, cov_val = draw_values([mu_sym, cov_sym], point=point)
    return mu_val, cov_val
236240

241+
def predictt(self, Xnew, diag=False, pred_noise=False, **given):
    """Symbolic (un-evaluated) predictive mean and covariance at `Xnew`."""
    givens = self._get_given_vals(**given)
    # Pass pred_noise/diag by keyword: positional arguments after
    # *givens are Python-3-only syntax and break on Python 2.
    mu, cov = self._build_conditional(Xnew, *givens,
                                      pred_noise=pred_noise, diag=diag)
    return mu, cov
245+
237246

238247
@conditioned_vars(["X", "Xu", "y", "sigma"])
239-
class MarginalSparse(Base):
248+
class MarginalSparse(Marginal):
240249
_available_approx = ["FITC", "VFE", "DTC"]
241250
""" FITC and VFE sparse approximations
242251
"""
243252
def __init__(self, mean_func=None, cov_func=None, approx="FITC"):
    """Set up a sparse marginal GP with the requested approximation.

    Raises NotImplementedError when `approx` is not one of the entries
    in `_available_approx` ("FITC", "VFE", "DTC").
    """
    if approx in self._available_approx:
        self.approx = approx
    else:
        raise NotImplementedError(approx)
    super(MarginalSparse, self).__init__(mean_func, cov_func)
246257

@@ -260,9 +271,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
260271
Luu = cholesky(stabilize(Kuu))
261272
A = solve_lower(Luu, Kuf)
262273
Qffd = tt.sum(A * A, 0)
263-
if self.approx not in self._available_approx:
264-
raise NotImplementedError(self.approx)
265-
elif self.approx == "FITC":
274+
if self.approx == "FITC":
266275
Kffd = self.cov_func(X, diag=True)
267276
Lamd = tt.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
268277
trace = 0.0
@@ -293,9 +302,8 @@ def marginal_likelihood(self, name, X, Xu, y, sigma, n_points=None, is_observed=
293302
if is_observed:  # same thing with n_points here? -- check
294303
return pm.DensityDist(name, logp, observed=y)
295304
else:
296-
if n_points is None:
297-
raise ValueError("When `y` is not observed, `n_points` arg is required")
298-
return pm.DensityDist(name, logp, size=n_points) # not, dont need size arg
305+
n_points = infer_shape(X, n_points)
306+
return pm.DensityDist(name, logp, size=n_points)
299307

300308
def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
301309
pred_noise, diag=False):
@@ -305,9 +313,7 @@ def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
305313
Luu = cholesky(stabilize(Kuu))
306314
A = solve_lower(Luu, Kuf)
307315
Qffd = tt.sum(A * A, 0)
308-
if self.approx not in self._available_approx:
309-
raise NotImplementedError(self.approx)
310-
elif self.approx == "FITC":
316+
if self.approx == "FITC":
311317
Kffd = cov_total(X, diag=True)
312318
Lamd = tt.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
313319
else: # VFE or DTC
@@ -334,25 +340,17 @@ def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
334340
cov += sigma2 * tt.identity_like(cov)
335341
return mu, stabilize(cov)
336342

337-
def _get_cond_vals(self, other=None):
338-
if other is None:
339-
return self.X, self.Xu, self.y, self.sigma, self.cov_func, self.mean_func,
343+
def _get_given_vals(self, **given):
344+
if 'gp' in given:
345+
cov_total = given['gp'].cov_func
346+
mean_total = given['gp'].mean_func
340347
else:
341-
return other.X, self.Xu, other.y, other.sigma, other.cov_func, other.mean_func
342-
343-
def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
344-
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
345-
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
346-
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
347-
pred_noise, diag=False)
348-
chol = cholesky(cov)
349-
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
350-
351-
def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
352-
X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
353-
mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
354-
pred_noise, diag)
355-
mu, cov = draw_values([mu, cov], point=point)
356-
return mu, cov
348+
cov_total = self.cov_func
349+
mean_total = self.mean_func
350+
if all(val in given for val in ['X', 'Xu', 'y', 'sigma']):
351+
X, Xu, y, sigma = given['X'], given['Xu'], given['y'], given['sigma']
352+
else:
353+
X, Xu, y, sigma = self.X, self.y, self.sigma
354+
return X, y, sigma, cov_total, mean_total
357355

358356

0 commit comments

Comments
 (0)