Skip to content

Commit ba8cce6

Browse files
committed
bug fixes, finalizing structure
1 parent 16a8cda commit ba8cce6

File tree

1 file changed

+35
-61
lines changed

1 file changed

+35
-61
lines changed

pymc3/gp/gp.py

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pymc3 as pm
88
from pymc3.gp.cov import Covariance
9+
from pymc3.gp.mean import Constant
910
from pymc3.gp.util import conditioned_vars
1011
from pymc3.distributions import draw_values
1112

@@ -21,7 +22,7 @@ def stabilize(K):
2122
return K + 1e-6 * tt.identity_like(K)
2223

2324

24-
class GPBase(object):
25+
class Base(object):
2526
"""
2627
Base class
2728
"""
@@ -37,28 +38,6 @@ def __init__(self, mean_func=None, cov_func=None):
3738
self.mean_func = mean_func
3839
self.cov_func = cov_func
3940

40-
#@property
41-
#def cov_total(self):
42-
# total = getattr(self, "_cov_total", None)
43-
# if total is None:
44-
# return self.cov_func
45-
# else:
46-
# return total
47-
#@cov_total.setter
48-
#def cov_total(self, new_cov_total):
49-
# self._cov_total = new_cov_total
50-
51-
#@property
52-
#def mean_total(self):
53-
# total = getattr(self, "_mean_total", None)
54-
# if total is None:
55-
# return self.mean_func
56-
# else:
57-
# return total
58-
#@mean_total.setter
59-
#def mean_total(self, new_mean_total):
60-
# self._mean_total = new_mean_total
61-
6241
def __add__(self, other):
6342
same_attrs = set(self.__dict__.keys()) == set(other.__dict__.keys())
6443
if not isinstance(self, type(other)) and not same_attrs:
@@ -78,7 +57,7 @@ def conditional(self, name, n_points, Xnew, *args, **kwargs):
7857

7958

8059
@conditioned_vars(["X", "f"])
81-
class Latent(GPBase):
60+
class Latent(Base):
8261
""" Where the GP f isn't integrated out, and is sampled explicitly
8362
"""
8463
def __init__(self, mean_func=None, cov_func=None):
@@ -92,40 +71,46 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
9271
f = pm.Deterministic(name, mu + tt.dot(chol, v))
9372
else:
9473
f = pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
95-
self.X = X
96-
self.f = f
9774
return f
9875

9976
def prior(self, name, n_points, X, reparameterize=True):
10077
f = self._build_prior(name, n_points, X, reparameterize)
78+
self.X = X
79+
self.f = f
10180
return f
10281

103-
def _build_conditional(self, Xnew, X, f):
104-
Kxx = self.cov_total(X)
82+
def _get_cond_vals(self, other=None):
83+
if other is None:
84+
return self.X, self.f, self.cov_func, self.mean_func,
85+
else:
86+
return other.X, other.f, other.cov_func, other.mean_func
87+
88+
def _build_conditional(self, Xnew, X, f, cov_total, mean_total, diag=False):
89+
Kxx = cov_total(X)
10590
Kxs = self.cov_func(X, Xnew)
106-
Kss = self.cov_func(Xnew)
10791
L = cholesky(stabilize(Kxx))
10892
A = solve_lower(L, Kxs)
109-
cov = Kss - tt.dot(tt.transpose(A), A)
110-
chol = cholesky(stabilize(cov))
111-
v = solve_lower(L, f - self.mean_total(X))
93+
v = solve_lower(L, f - mean_total(X))
11294
mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
113-
return mu, chol
95+
if diag:
96+
Kss = self.cov_func(Xnew, diag=True)
97+
cov = Kss - tt.sum(tt.square(A), 0)
98+
else:
99+
Kss = self.cov_func(Xnew)
100+
cov = Kss - tt.dot(tt.transpose(A), A)
101+
return mu, cov
114102

115-
def conditional(self, name, n_points, Xnew, X=None, f=None):
116-
if X is None: X = self.X
117-
if f is None: f = self.f
118-
mu, chol = self._build_conditional(Xnew, X, f)
103+
def conditional(self, name, n_points, Xnew, gp=None):
104+
X, f, cov_total, mean_total = self._get_cond_vals(gp)
105+
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total)
106+
chol = cholesky(stabilize(cov))
119107
return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
120108

121-
def conditional2(self, name, n_points, X_new, gp_fitted=None):
122-
# can't condition on a gp that hasn't been fit, so gp should have gp.f
123-
# reconsider: is allowing X to be swapped out here actually useful? remove it?
124-
125-
if gp is not None:
126-
X = gp.X
127-
f = gp.f
128-
109+
def predict(self, Xnew, point=None, gp=None, diag=False):
110+
X, f, cov_total, mean_total = self._get_cond_vals(gp)
111+
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total, diag)
112+
mu, cov = draw_values([mu, cov], point=point)
113+
return mu, cov
129114

130115

131116
@conditioned_vars(["X", "f", "nu"])
@@ -149,14 +134,13 @@ def _build_prior(self, name, n_points, X, nu):
149134
chi2 = pm.ChiSquared("chi2_", nu)
150135
v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
151136
f = pm.Deterministic(name, (tt.sqrt(nu) / chi2) * (mu + tt.dot(chol, v)))
152-
153-
self.X = X
154-
self.f = f
155-
self.nu = nu
156137
return f
157138

158139
def prior(self, name, n_points, X, nu):
159140
f = self._build_prior(name, n_points, X, nu)
141+
self.X = X
142+
self.nu = nu
143+
self.f = f
160144
return f
161145

162146
def _build_conditional(self, Xnew, X, f, nu):
@@ -186,7 +170,7 @@ def conditional(self, name, n_points, Xnew, X=None, f=None, nu=None):
186170

187171

188172
@conditioned_vars(["X", "y", "noise"])
189-
class Marginal(GPBase):
173+
class Marginal(Base):
190174

191175
def __init__(self, mean_func=None, cov_func=None):
192176
super(Marginal, self).__init__(mean_func, cov_func)
@@ -215,7 +199,6 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
215199

216200
def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
217201
pred_noise, diag=False):
218-
# when not conditioning on another gp, cov_total = self.cov_func
219202
Kxx = cov_total(X)
220203
Kxs = self.cov_func(X, Xnew)
221204
Knx = noise(X)
@@ -258,18 +241,9 @@ def predict(self, Xnew, point=None, gp=None, pred_noise=False, diag=False):
258241
mu, cov = draw_values([mu, cov], point=point)
259242
return mu, cov
260243

261-
#def conditional(self, name, n_points, Xnew, X=None, y=None,
262-
# noise=None, pred_noise=False):
263-
# X, y, noise = self._get_cond_vals(X, y, noise)
264-
# mu, cov = self._build_conditional(Xnew, X, y, noise, pred_noise, diag=False)
265-
# chol = cholesky(cov)
266-
# return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
267-
268-
269-
270244

271245
@conditioned_vars(["X", "Xu", "y", "sigma"])
272-
class MarginalSparse(GPBase):
246+
class MarginalSparse(Base):
273247
_available_approx = ["FITC", "VFE", "DTC"]
274248
""" FITC and VFE sparse approximations
275249
"""

0 commit comments

Comments
 (0)