Skip to content

Commit 8fcdb5d

Browse files
committed
try to fix test error again
1 parent 6164676 commit 8fcdb5d

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

pymc3/gp/gp.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,8 @@ def conditional(self, name, Xnew, given={}, **kwargs):
192192
constructor.
193193
"""
194194

195-
X, f, cov_total, mean_total = self._get_given_vals(given)
196-
mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total)
195+
givens = self._get_given_vals(given)
196+
mu, cov = self._build_conditional(Xnew, *givens)
197197
chol = cholesky(stabilize(cov))
198198
shape = infer_shape(Xnew, kwargs.pop("shape", None))
199199
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
@@ -423,8 +423,8 @@ def _get_given_vals(self, given):
423423
X, y, noise = self.X, self.y, self.noise
424424
return X, y, noise, cov_total, mean_total
425425

426-
def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
427-
pred_noise, diag=False):
426+
def _build_conditional(self, Xnew, pred_noise, diag, X, y, noise,
427+
cov_total, mean_total):
428428
Kxx = cov_total(X)
429429
Kxs = self.cov_func(X, Xnew)
430430
Knx = noise(X)
@@ -478,9 +478,8 @@ def conditional(self, name, Xnew, pred_noise=False, given={}, **kwargs):
478478
constructor.
479479
"""
480480

481-
X, y, noise, cov_total, mean_total = self._get_given_vals(given)
482-
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
483-
pred_noise, diag=False)
481+
givens = self._get_given_vals(given)
482+
mu, cov = self._build_conditional(Xnew, pred_noise, False, *givens)
484483
chol = cholesky(cov)
485484
shape = infer_shape(Xnew, kwargs.pop("shape", None))
486485
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
@@ -531,9 +530,8 @@ def predictt(self, Xnew, diag=False, pred_noise=False, given={}):
531530
Same as `conditional` method.
532531
"""
533532

534-
X, y, noise, cov_total, mean_total = self._get_given_vals(given)
535-
mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total,
536-
mean_total, pred_noise, diag)
533+
givens = self._get_given_vals(given)
534+
mu, cov = self._build_conditional(Xnew, pred_noise, diag, *givens)
537535
return mu, cov
538536

539537

@@ -680,8 +678,7 @@ def marginal_likelihood(self, name, X, Xu, y, sigma, is_observed=True, **kwargs)
680678
shape = infer_shape(X, kwargs.pop("shape", None))
681679
return pm.DensityDist(name, logp, shape=shape, **kwargs)
682680

683-
def _build_conditional(self, Xnew, X, Xu, y, sigma, cov_total, mean_total,
684-
pred_noise, diag=False):
681+
def _build_conditional(self, Xnew, pred_noise, diag, X, Xu, y, sigma, cov_total, mean_total):
685682
sigma2 = tt.square(sigma)
686683
Kuu = cov_total(Xu)
687684
Kuf = cov_total(Xu, X)
@@ -752,10 +749,8 @@ def conditional(self, name, Xnew, pred_noise=False, given={}, **kwargs):
752749
constructor.
753750
"""
754751

755-
X, Xu, y, sigma, cov_total, mean_total = self._get_given_vals(given)
756-
mu, cov = self._build_conditional(Xnew, X, Xu, y, sigma, cov_total,
757-
mean_total, pred_noise, diag=False)
752+
givens = self._get_given_vals(given)
753+
mu, cov = self._build_conditional(Xnew, pred_noise, False, *givens)
758754
chol = cholesky(cov)
759755
shape = infer_shape(Xnew, kwargs.pop("shape", None))
760756
return pm.MvNormal(name, mu=mu, chol=chol, shape=shape, **kwargs)
761-

0 commit comments

Comments
 (0)