Skip to content

Commit 290f43b

Browse files
sdaultonfacebook-github-bot
authored andcommitted
allow different inferred noise levels for each task in MultitaskGP (#2997)
Summary: Pull Request resolved: #2997 X-link: facebook/Ax#4233 see title Reviewed By: Balandat Differential Revision: D81373665 fbshipit-source-id: 5afc6887aef15e6641bb864b6b725c128dce10d1
1 parent c431c6d commit 290f43b

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

botorch/models/multitask.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
from botorch.models.utils.assorted import get_task_value_remapping
4343
from botorch.models.utils.gpytorch_modules import (
4444
get_covar_module_with_dim_scaled_prior,
45-
get_gaussian_likelihood_with_lognormal_prior,
4645
MIN_INFERRED_NOISE_LEVEL,
4746
)
4847
from botorch.posteriors.multitask import MultitaskGPPosterior
@@ -56,6 +55,7 @@
5655
from gpytorch.kernels.index_kernel import IndexKernel
5756
from gpytorch.kernels.multitask_kernel import MultitaskKernel
5857
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
58+
from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood
5959
from gpytorch.likelihoods.likelihood import Likelihood
6060
from gpytorch.likelihoods.multitask_gaussian_likelihood import (
6161
MultitaskGaussianLikelihood,
@@ -212,10 +212,20 @@ def __init__(
212212
self._output_tasks = output_tasks
213213
self._num_outputs = len(output_tasks)
214214

215-
# TODO (T41270962): Support task-specific noise levels in likelihood
216215
if likelihood is None:
217216
if train_Yvar is None:
218-
likelihood = get_gaussian_likelihood_with_lognormal_prior()
217+
noise_prior = LogNormalPrior(loc=-4.0, scale=1.0)
218+
likelihood = HadamardGaussianLikelihood(
219+
num_tasks=self.num_tasks,
220+
batch_shape=torch.Size(),
221+
noise_prior=noise_prior,
222+
noise_constraint=GreaterThan(
223+
MIN_INFERRED_NOISE_LEVEL,
224+
transform=None,
225+
initial_value=noise_prior.mode,
226+
),
227+
task_feature_index=task_feature,
228+
)
219229
else:
220230
likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
221231

test/models/test_multitask.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from gpytorch.likelihoods import (
3131
FixedNoiseGaussianLikelihood,
3232
GaussianLikelihood,
33+
HadamardGaussianLikelihood,
3334
MultitaskGaussianLikelihood,
3435
)
3536
from gpytorch.means import ConstantMean, MultitaskMean
@@ -154,7 +155,9 @@ def test_MultiTaskGP(self) -> None:
154155
if fixed_noise:
155156
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
156157
else:
157-
self.assertIsInstance(model.likelihood, GaussianLikelihood)
158+
self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood)
159+
self.assertEqual(model.likelihood.noise.shape, torch.Size([2]))
160+
self.assertEqual(model.likelihood.task_feature_index, 0)
158161
data_covar_module, task_covar_module = model.covar_module.kernels
159162
self.assertIsInstance(model.mean_module, ConstantMean)
160163
self.assertIsInstance(data_covar_module, RBFKernel)
@@ -195,8 +198,8 @@ def test_MultiTaskGP(self) -> None:
195198
torch.tensor([0.05, 0.1], **tkwargs).repeat_interleave(2)
196199
).expand(3, 4, 4)
197200
else:
198-
noise_covar = model.likelihood.noise_covar.noise * torch.eye(
199-
4, **tkwargs
201+
noise_covar = torch.diag(
202+
model.likelihood.noise_covar.noise.repeat_interleave(2)
200203
).expand(3, 4, 4)
201204
expected_y_covar = posterior_f.covariance_matrix + noise_covar
202205
self.assertTrue(
@@ -337,7 +340,7 @@ def test_MultiTaskGP_single_output(self) -> None:
337340
data_covar_module, task_covar_module = model.covar_module.kernels
338341
self.assertIsInstance(model, MultiTaskGP)
339342
self.assertEqual(model.num_outputs, 1)
340-
self.assertIsInstance(model.likelihood, GaussianLikelihood)
343+
self.assertIsInstance(model.likelihood, HadamardGaussianLikelihood)
341344
self.assertIsInstance(model.mean_module, ConstantMean)
342345
self.assertIsInstance(data_covar_module, RBFKernel)
343346
self.assertIsInstance(data_covar_module.lengthscale_prior, LogNormalPrior)

0 commit comments

Comments
 (0)