Skip to content

Commit 222f487

Browse files
committed
fix bug in latent, move appox arg to __init__
1 parent 15ff63f commit 222f487

File tree

1 file changed

+14
-17
lines changed

1 file changed

+14
-17
lines changed

pymc3/gp/gp.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -193,31 +193,28 @@ def _prior_rv(self, X, y, sigma=None, cov_func_noise=None):
193193

194194
def _predictive_rv(self, X, y, Xs, sigma=None, cov_func_noise=None):
195195
cov_func_noise = self._to_noise_func(sigma, cov_func_noise)
196-
mu, chol = self._build_predictive(X, Xs, y, cov_func_noise)
196+
mu, chol = self._build_predictive(X, y, Xs, cov_func_noise)
197197
return pm.MvNormal(self.name, mu=mu, chol=chol, shape=self.size)
198198

199199

200200

201201
class GPMarginalSparse(GPBase):
202202
""" FITC and VFE sparse approximations
203203
"""
204-
def __init__(self, cov_func):
204+
def __init__(self, cov_func, approx=None):
205+
if approx is None:
206+
approx = "FITC"
207+
self.approx = approx
205208
super(GPMarginalSparse, self).__init__(cov_func)
206209

207-
def __call__(self, name, size, mean_func, include_noise=False, approx=None):
208-
if hasattr(self, "approx") and self.approx != approx:
209-
raise ValueError("dont use diff approx (should this be a warning?)")
210-
else:
211-
if approx is None:
212-
pm._log.info("Using FITC approximation for {}".format(name))
213-
approx = "FITC"
214-
else:
215-
approx = approx.upper()
216-
if approx not in ["VFE", "FITC"]:
217-
raise ValueError(("'FITC' or 'VFE' are the supported GP "
218-
"approximations, not {}".format(approx)))
219-
self.approx = approx
210+
# overriding __add__, since whether its vfe or fitc determines its 'type'
211+
def __add__(self, other):
212+
if not (isinstance(self, type(other)) and self.approx == other.approx):
213+
raise ValueError("cant add different GP types")
214+
cov_func = self.cov_func + other.cov_func
215+
return type(self)(cov_func)
220216

217+
def __call__(self, name, size, mean_func, include_noise=False):
221218
self.include_noise = include_noise
222219
return super(GPMarginalSparse, self).__call__(name, size, mean_func)
223220

@@ -240,7 +237,7 @@ def kmeans_inducing_points(self, n_inducing, X):
240237
Xu, distortion = kmeans(Xw, n_inducing)
241238
return Xu * scaling
242239

243-
def _build_prior(self, X, Xu, y, sigma):
240+
def _build_prior_logp(self, X, Xu, y, sigma):
244241
sigma2 = tt.square(sigma)
245242
Kuu = self.cov_func(Xu)
246243
Kuf = self.cov_func(Xu, X)
@@ -257,7 +254,7 @@ def _build_prior(self, X, Xu, y, sigma):
257254
(tt.sum(self.cov_func(X, diag=True)) -
258255
tt.sum(tt.sum(A * A, 0))))
259256
else:
260-
raise NotImplementedError(approx)
257+
raise NotImplementedError(self.approx)
261258
A_l = A / Lamd
262259
L_B = cholesky(tt.eye(Xu.shape[0]) + tt.dot(A_l, tt.transpose(A)))
263260
r = y - self.mean_func(X)

0 commit comments

Comments
 (0)