Skip to content

Commit 5ab36b1

Browse files
esantorellafacebook-github-bot
authored andcommitted
Follow-ups from D40783106 (fantasize refactor) (#1479)
Summary: see T137147547 Remove inappropriate subclassing of `SingleTaskGP` (and clean up associated to-dos) [x] Make `HeteroskedasticSingleTaskGP` not a subclass of `SingleTaskGP` since it can't `fantasize` [x] Make `SaasFullyBayesianSingleTaskGP` not a subclass of `SingleTaskGP` since it can't `fantasize` [x] Fix downstream problems in multiple dispatched this caused.... Add `fantasize` back to classes that weren't using it, but can call it and used to have it [x] Make `MultiTaskGP` have a `fantasize` method (debatable! it used to have one, but wasn't used) [x] Restore `fantasize` method with `NotImplementedError` to `SaasFullyBayesianMultiTaskGP` (as a result of adding `fantasize` to `MultiTaskGP`) [x] Make `KroneckerMultiTaskGP` have a `fantasize` method (debatable! it used to have one, but wasn't used) Pull Request resolved: #1479 Reviewed By: saitcakmak Differential Revision: D41053234 Pulled By: esantorella fbshipit-source-id: 0bf473655cc87061e87406b620dc604a370d3388
1 parent 9d90635 commit 5ab36b1

File tree

6 files changed

+63
-20
lines changed

6 files changed

+63
-20
lines changed

botorch/models/fully_bayesian.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import pyro
3939
import torch
4040
from botorch.acquisition.objective import PosteriorTransform
41-
from botorch.models.gp_regression import SingleTaskGP
41+
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
4242
from botorch.models.transforms.input import InputTransform
4343
from botorch.models.transforms.outcome import OutcomeTransform
4444
from botorch.models.utils import validate_input_scaling
@@ -54,6 +54,7 @@
5454
from gpytorch.likelihoods.likelihood import Likelihood
5555
from gpytorch.means.constant_mean import ConstantMean
5656
from gpytorch.means.mean import Mean
57+
from gpytorch.models.exact_gp import ExactGP
5758
from torch import Tensor
5859

5960
MIN_INFERRED_NOISE_LEVEL = 1e-6
@@ -294,7 +295,7 @@ def load_mcmc_samples(
294295
return mean_module, covar_module, likelihood
295296

296297

297-
class SaasFullyBayesianSingleTaskGP(SingleTaskGP):
298+
class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
298299
r"""A fully Bayesian single-task GP model with the SAAS prior.
299300
300301
This model assumes that the inputs have been normalized to [0, 1]^d and that
@@ -364,7 +365,7 @@ def __init__(
364365
train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL)
365366

366367
X_tf, Y_tf, _ = self._transform_tensor_args(X=train_X, Y=train_Y)
367-
super(SingleTaskGP, self).__init__(
368+
super().__init__(
368369
train_inputs=X_tf, train_targets=Y_tf, likelihood=GaussianLikelihood()
369370
)
370371
self.mean_module = None
@@ -473,9 +474,19 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
473474
super().load_state_dict(state_dict=state_dict, strict=strict)
474475

475476
def forward(self, X: Tensor) -> MultivariateNormal:
477+
"""
478+
Unlike in other classes' `forward` methods, there is no `if self.training`
479+
block, because it ought to be unreachable: If `self.train()` has been called,
480+
then `self.covar_module` will be None, `check_if_fitted()` will fail, and the
481+
rest of this method will not run.
482+
"""
476483
self._check_if_fitted()
477-
return super().forward(X.unsqueeze(MCMC_DIM))
484+
x = X.unsqueeze(MCMC_DIM)
485+
mean_x = self.mean_module(x)
486+
covar_x = self.covar_module(x)
487+
return MultivariateNormal(mean_x, covar_x)
478488

489+
# pyre-ignore[14]: Inconsistent override
479490
def posterior(
480491
self,
481492
X: Tensor,

botorch/models/fully_bayesian_multitask.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010

11-
from typing import Any, Dict, List, Optional, Tuple
11+
from typing import Any, Dict, List, NoReturn, Optional, Tuple
1212

1313
import pyro
1414
import torch
@@ -299,6 +299,9 @@ def batch_shape(self) -> torch.Size:
299299
self._check_if_fitted()
300300
return torch.Size([self.num_mcmc_samples])
301301

302+
def fantasize(self, *args, **kwargs) -> NoReturn:
303+
raise NotImplementedError("Fantasize is not implemented!")
304+
302305
def _check_if_fitted(self):
303306
r"""Raise an exception if the model hasn't been fitted."""
304307
if self.covar_module is None:
@@ -321,6 +324,8 @@ def load_mcmc_samples(self, mcmc_samples: Dict[str, Tensor]) -> None:
321324
self.latent_features,
322325
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
323326

327+
# pyre-fixme[14]: Inconsistent override of
328+
# BatchedMultiOutputGPyTorchModel.posterior
324329
def posterior(
325330
self,
326331
X: Tensor,
@@ -345,6 +350,7 @@ def posterior(
345350
posterior = FullyBayesianPosterior(mvn=posterior.mvn)
346351
return posterior
347352

353+
# pyre-fixme[14]: Inconsistent override
348354
def forward(self, X: Tensor) -> MultivariateNormal:
349355
self._check_if_fitted()
350356
X = X.unsqueeze(MCMC_DIM)
@@ -373,6 +379,7 @@ def forward(self, X: Tensor) -> MultivariateNormal:
373379
return MultivariateNormal(mean_x, covar)
374380

375381
@classmethod
382+
# pyre-fixme[14]: Inconsistent override of `MultiTaskGP.construct_inputs`
376383
def construct_inputs(
377384
cls,
378385
training_data: Dict[str, SupervisedDataset],

botorch/models/gp_regression.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
360360
return new_model
361361

362362

363-
class HeteroskedasticSingleTaskGP(SingleTaskGP):
363+
class HeteroskedasticSingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
364364
r"""A single-task exact GP model using a heteroskedastic noise model.
365365
366366
This model differs from `SingleTaskGP` in that noise levels are provided
@@ -423,7 +423,12 @@ def __init__(
423423
input_transform=input_transform,
424424
)
425425
likelihood = _GaussianLikelihoodBase(HeteroskedasticNoise(noise_model))
426-
super().__init__(
426+
# This is hacky -- this class used to inherit from SingleTaskGP, but it
427+
# shouldn't so this is a quick fix to enable getting rid of that
428+
# inheritance
429+
SingleTaskGP.__init__(
430+
# pyre-fixme[6]: Incompatible parameter type
431+
self,
427432
train_X=train_X,
428433
train_Y=train_Y,
429434
likelihood=likelihood,
@@ -437,15 +442,17 @@ def __init__(
437442
self.outcome_transform = outcome_transform
438443
self.to(train_X)
439444

440-
# TODO: HeteroskedasticSingleTaskGP should not be a subclass of
441-
# SingleTaskGP because it can't function the way a SingleTaskGP does
442445
# pyre-fixme[15]: Inconsistent override
443446
def condition_on_observations(self, *_, **__) -> NoReturn:
444447
raise NotImplementedError
445448

446-
def fantasize(self, *_, **__) -> NoReturn:
447-
raise NotImplementedError
448-
449449
# pyre-fixme[15]: Inconsistent override
450450
def subset_output(self, idcs) -> NoReturn:
451451
raise NotImplementedError
452+
453+
def forward(self, x: Tensor) -> MultivariateNormal:
454+
if self.training:
455+
x = self.transform_inputs(x)
456+
mean_x = self.mean_module(x)
457+
covar_x = self.covar_module(x)
458+
return MultivariateNormal(mean_x, covar_x)

botorch/models/multitask.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from botorch.acquisition.objective import PosteriorTransform
2929
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
3030
from botorch.models.gpytorch import GPyTorchModel, MultiTaskGPyTorchModel
31+
from botorch.models.model import FantasizeMixin
3132
from botorch.models.transforms.input import InputTransform
3233
from botorch.models.transforms.outcome import OutcomeTransform
3334
from botorch.posteriors.multitask import MultitaskGPPosterior
@@ -71,7 +72,7 @@
7172
from torch import Tensor
7273

7374

74-
class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel):
75+
class MultiTaskGP(ExactGP, MultiTaskGPyTorchModel, FantasizeMixin):
7576
r"""Multi-Task GP model using an ICM kernel, inferring observation noise.
7677
7778
Multi-task exact GP that uses a simple ICM kernel. Can be single-output or
@@ -377,7 +378,7 @@ def __init__(
377378
self.to(train_X)
378379

379380

380-
class KroneckerMultiTaskGP(ExactGP, GPyTorchModel):
381+
class KroneckerMultiTaskGP(ExactGP, GPyTorchModel, FantasizeMixin):
381382
"""Multi-task GP with Kronecker structure, using an ICM kernel.
382383
383384
This model assumes the "block design" case, i.e., it requires that all tasks

test/models/test_fully_bayesian_multitask.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ def test_raises(self):
159159
task_feature=4,
160160
)
161161
train_X, train_Y, train_Yvar, model = self._get_data_and_model(**tkwargs)
162+
sampler = IIDNormalSampler(num_samples=2)
163+
with self.assertRaisesRegex(
164+
NotImplementedError, "Fantasize is not implemented!"
165+
):
166+
model.fantasize(
167+
X=torch.cat(
168+
[torch.rand(5, 4, **tkwargs), torch.ones(5, 1, **tkwargs)], dim=1
169+
),
170+
sampler=sampler,
171+
)
162172

163173
# Make sure an exception is raised if the model has not been fitted
164174
not_fitted_error_msg = (

test/models/test_gp_regression.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -454,12 +454,23 @@ def _get_model_and_data(
454454
model = HeteroskedasticSingleTaskGP(**model_kwargs)
455455
return model, model_kwargs
456456

457-
def test_custom_init(self):
458-
pass
457+
def test_custom_init(self) -> None:
458+
"""
459+
This test exists because `TestHeteroskedasticSingleTaskGP` inherits from
460+
`TestSingleTaskGP`, which has a `test_custom_init` method that isn't relevant
461+
for `TestHeteroskedasticSingleTaskGP`.
462+
"""
459463

460464
def test_gp(self):
461465
super().test_gp(double_only=True)
462466

467+
def test_fantasize(self) -> None:
468+
"""
469+
This test exists because `TestHeteroskedasticSingleTaskGP` inherits from
470+
`TestSingleTaskGP`, which has a `fantasize` method that isn't relevant
471+
for `TestHeteroskedasticSingleTaskGP`.
472+
"""
473+
463474
def test_heteroskedastic_likelihood(self):
464475
for batch_shape, m, dtype in itertools.product(
465476
(torch.Size(), torch.Size([2])), (1, 2), (torch.float, torch.double)
@@ -480,10 +491,6 @@ def test_condition_on_observations(self):
480491
with self.assertRaises(NotImplementedError):
481492
super().test_condition_on_observations()
482493

483-
def test_fantasize(self):
484-
with self.assertRaises(NotImplementedError):
485-
super().test_fantasize()
486-
487494
def test_subset_model(self):
488495
with self.assertRaises(NotImplementedError):
489496
super().test_subset_model()

0 commit comments

Comments
 (0)