Skip to content

Commit a6eddcb

Browse files
sdaultonfacebook-github-bot
authored andcommitted
require output dim in MultiTaskGP (#383)
Summary: Pull Request resolved: #383 - Require output dim in MTGP - Validate training tensor dimensions - Validate input scaling Reviewed By: Balandat Differential Revision: D20223856 fbshipit-source-id: 8dfd327ecc6a9bb141211dd43148baa34c22ed70
1 parent 4302a0c commit a6eddcb

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

botorch/models/multitask.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torch import Tensor
2828

2929
from .gpytorch import MultiTaskGPyTorchModel
30+
from .utils import validate_input_scaling
3031

3132

3233
class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel):
@@ -74,10 +75,13 @@ def __init__(
7475
>>> train_Y = torch.cat(f1(X1), f2(X2))
7576
>>> model = MultiTaskGP(train_X, train_Y, task_feature=-1)
7677
"""
77-
# TODO: Validate input normalization/scaling
78-
if train_X.ndimension() != 2:
78+
self._validate_tensor_args(X=train_X, Y=train_Y)
79+
validate_input_scaling(train_X=train_X, train_Y=train_Y)
80+
if train_X.ndim != 2:
7981
# Currently, batch mode MTGPs are blocked upstream in GPyTorch
8082
raise ValueError(f"Unsupported shape {train_X.shape} for train_X.")
83+
# squeeze output dim
84+
train_Y = train_Y.squeeze(-1)
8185
d = train_X.shape[-1] - 1
8286
if not (-d <= task_feature <= d):
8387
raise ValueError(f"Must have that -{d} <= task_feature <= {d}")
@@ -199,6 +203,7 @@ def __init__(
199203
>>> train_Yvar = 0.1 + 0.1 * torch.rand_like(train_Y)
200204
>>> model = FixedNoiseMultiTaskGP(train_X, train_Y, train_Yvar, -1)
201205
"""
206+
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
202207
# We'll instatiate a MultiTaskGP and simply override the likelihood
203208
super().__init__(
204209
train_X=train_X,
@@ -207,5 +212,5 @@ def __init__(
207212
output_tasks=output_tasks,
208213
rank=rank,
209214
)
210-
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar)
215+
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar.squeeze(-1))
211216
self.to(train_X)

test/models/test_multitask.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _get_random_mt_data(**tkwargs):
3131
full_train_i = torch.cat([train_i_task1, train_i_task2])
3232
full_train_y = torch.cat([train_y1, train_y2])
3333
train_X = torch.stack([full_train_x, full_train_i.type_as(full_train_x)], dim=-1)
34-
train_Y = full_train_y
34+
train_Y = full_train_y.unsqueeze(-1) # add output dim
3535
return train_X, train_Y
3636

3737

@@ -121,7 +121,7 @@ def test_MultiTaskGP(self):
121121

122122
# test that unsupported batch shape MTGPs throw correct error
123123
with self.assertRaises(ValueError):
124-
MultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 1), 0)
124+
MultiTaskGP(torch.rand(2, 2, 2), torch.rand(2, 2, 1), 0)
125125

126126
# test that bad feature index throws correct error
127127
train_X, train_Y = _get_random_mt_data(**tkwargs)
@@ -233,7 +233,7 @@ def test_FixedNoiseMultiTaskGP(self):
233233
# test that unsupported batch shape MTGPs throw correct error
234234
with self.assertRaises(ValueError):
235235
FixedNoiseMultiTaskGP(
236-
torch.rand(2, 2, 2), torch.rand(2, 1), torch.rand(2, 1), 0
236+
torch.rand(2, 2, 2), torch.rand(2, 2, 1), torch.rand(2, 2, 1), 0
237237
)
238238

239239
# test that bad feature index throws correct error

0 commit comments

Comments
 (0)