Skip to content

Commit e8fb09a

Browse files
committed
fix n_points in Marginal.marginal_likelihood
1 parent d574edf commit e8fb09a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pymc3/gp/gp.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,18 @@ def _build_marginal_likelihood(self, X, noise):
198198
chol = cholesky(stabilize(cov))
199199
return mu, chol
200200

201-
def marginal_likelihood(self, name, n_points, X, y, noise, is_observed=True):
201+
def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True):
202202
if not isinstance(noise, Covariance):
203203
noise = pm.gp.cov.WhiteNoise(noise)
204204
mu, chol = self._build_marginal_likelihood(X, noise)
205205
self.X = X
206206
self.y = y
207207
self.noise = noise
208208
if is_observed:
209-
return pm.MvNormal(name, mu=mu, chol=chol, observed=y)
209+
return pm.MvNormal.dist(mu=mu, chol=chol).logp(y)
210210
else:
211+
if n_points is None:
212+
raise ValueError("When `y` is not observed, `n_points` arg is required")
211213
return pm.MvNormal(name, mu=mu, chol=chol, size=n_points)
212214

213215
def _build_conditional(self, Xnew, X, y, noise, pred_noise):
@@ -295,7 +297,7 @@ def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True)
295297
self.y = y
296298
self.sigma = sigma
297299
logp = lambda y: self._build_marginal_likelihood_logp(X, Xu, y, sigma)
298-
if is_observed:
300+
if is_observed: # same thing ith n_points here?? check
299301
return pm.DensityDist(name, logp, observed=y)
300302
else:
301303
return pm.DensityDist(name, logp, size=n_points) # need size? if not, dont need size arg

0 commit comments

Comments
 (0)