|
30 | 30 | from gpytorch.likelihoods import (
|
31 | 31 | FixedNoiseGaussianLikelihood,
|
32 | 32 | GaussianLikelihood,
|
| 33 | + HadamardGaussianLikelihood, |
33 | 34 | MultitaskGaussianLikelihood,
|
34 | 35 | )
|
35 | 36 | from gpytorch.means import ConstantMean, MultitaskMean
|
@@ -154,7 +155,9 @@ def test_MultiTaskGP(self) -> None:
|
154 | 155 | if fixed_noise:
|
155 | 156 | self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
|
156 | 157 | else:
|
157 |
| - self.assertIsInstance(model.likelihood, GaussianLikelihood) |
| 158 | + self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood) |
| 159 | + self.assertEqual(model.likelihood.noise.shape, torch.Size([2])) |
| 160 | + self.assertEqual(model.likelihood.task_feature_index, 0) |
158 | 161 | data_covar_module, task_covar_module = model.covar_module.kernels
|
159 | 162 | self.assertIsInstance(model.mean_module, ConstantMean)
|
160 | 163 | self.assertIsInstance(data_covar_module, RBFKernel)
|
@@ -195,8 +198,8 @@ def test_MultiTaskGP(self) -> None:
|
195 | 198 | torch.tensor([0.05, 0.1], **tkwargs).repeat_interleave(2)
|
196 | 199 | ).expand(3, 4, 4)
|
197 | 200 | else:
|
198 |
| - noise_covar = model.likelihood.noise_covar.noise * torch.eye( |
199 |
| - 4, **tkwargs |
| 201 | + noise_covar = torch.diag( |
| 202 | + model.likelihood.noise_covar.noise.repeat_interleave(2) |
200 | 203 | ).expand(3, 4, 4)
|
201 | 204 | expected_y_covar = posterior_f.covariance_matrix + noise_covar
|
202 | 205 | self.assertTrue(
|
@@ -337,7 +340,7 @@ def test_MultiTaskGP_single_output(self) -> None:
|
337 | 340 | data_covar_module, task_covar_module = model.covar_module.kernels
|
338 | 341 | self.assertIsInstance(model, MultiTaskGP)
|
339 | 342 | self.assertEqual(model.num_outputs, 1)
|
340 |
| - self.assertIsInstance(model.likelihood, GaussianLikelihood) |
| 343 | + self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood) |
341 | 344 | self.assertIsInstance(model.mean_module, ConstantMean)
|
342 | 345 | self.assertIsInstance(data_covar_module, RBFKernel)
|
343 | 346 | self.assertIsInstance(data_covar_module.lengthscale_prior, LogNormalPrior)
|
|
0 commit comments