Skip to content

Commit 381ff30

Browse files
committed
mtgl returns sumkroneckerlt
1 parent f5d1503 commit 381ff30

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

gpytorch/lazy/sum_kronecker_lazy_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
7575
logdet_term = None
7676

7777
if inv_quad_rhs is not None:
78-
solve = self._solve(inv_quad_rhs)
78+
solve = self.inv_matmul(inv_quad_rhs)
7979
inv_quad_term = (inv_quad_rhs * solve).sum(-2)
8080

8181
if inv_quad_term.numel() and reduce_inv_quad:

gpytorch/likelihoods/multitask_gaussian_likelihood.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DiagLazyTensor,
1313
KroneckerProductDiagLazyTensor,
1414
KroneckerProductLazyTensor,
15+
LazyEvaluatedKernelTensor,
1516
RootLazyTensor,
1617
)
1718
from ..likelihoods import Likelihood, _GaussianLikelihoodBase
@@ -88,6 +89,10 @@ def marginal(self, function_dist, *params, **kwargs):
8889
"""
8990
mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix
9091

92+
# ensure that sumKroneckerLT is actually called
93+
if isinstance(covar, LazyEvaluatedKernelTensor):
94+
covar = covar.evaluate_kernel()
95+
9196
covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise)
9297
covar = covar + covar_kron_lt
9398

@@ -126,7 +131,7 @@ def _shaped_noise_covar(self, shape, add_noise=True, *params, **kwargs):
126131

127132
def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal:
128133
noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag()
129-
noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:])
134+
noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:])
130135
return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1)
131136

132137

0 commit comments

Comments
 (0)