|
12 | 12 | DiagLazyTensor, |
13 | 13 | KroneckerProductDiagLazyTensor, |
14 | 14 | KroneckerProductLazyTensor, |
| 15 | + LazyEvaluatedKernelTensor, |
15 | 16 | RootLazyTensor, |
16 | 17 | ) |
17 | 18 | from ..likelihoods import Likelihood, _GaussianLikelihoodBase |
@@ -88,6 +89,10 @@ def marginal(self, function_dist, *params, **kwargs): |
88 | 89 | """ |
89 | 90 | mean, covar = function_dist.mean, function_dist.lazy_covariance_matrix |
90 | 91 |
|
| 92 | + # ensure that sumKroneckerLT is actually called |
| 93 | + if isinstance(covar, LazyEvaluatedKernelTensor): |
| 94 | + covar = covar.evaluate_kernel() |
| 95 | + |
91 | 96 | covar_kron_lt = self._shaped_noise_covar(mean.shape, add_noise=self.has_global_noise) |
92 | 97 | covar = covar + covar_kron_lt |
93 | 98 |
|
@@ -126,7 +131,7 @@ def _shaped_noise_covar(self, shape, add_noise=True, *params, **kwargs): |
126 | 131 |
|
127 | 132 | def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal: |
128 | 133 | noise = self._shaped_noise_covar(function_samples.shape, *params, **kwargs).diag() |
129 | | - noise = noise.view(*noise.shape[:-1], *function_samples.shape[-2:]) |
| 134 | + noise = noise.reshape(*noise.shape[:-1], *function_samples.shape[-2:]) |
130 | 135 | return base_distributions.Independent(base_distributions.Normal(function_samples, noise.sqrt()), 1) |
131 | 136 |
|
132 | 137 |
|
|
0 commit comments