38 | 38 | import pyro
39 | 39 | import torch
40 | 40 | from botorch.acquisition.objective import PosteriorTransform
41 | | -from botorch.models.gp_regression import SingleTaskGP
| 41 | +from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
42 | 42 | from botorch.models.transforms.input import InputTransform
43 | 43 | from botorch.models.transforms.outcome import OutcomeTransform
44 | 44 | from botorch.models.utils import validate_input_scaling
54 | 54 | from gpytorch.likelihoods.likelihood import Likelihood |
55 | 55 | from gpytorch.means.constant_mean import ConstantMean |
56 | 56 | from gpytorch.means.mean import Mean |
| 57 | +from gpytorch.models.exact_gp import ExactGP |
57 | 58 | from torch import Tensor |
58 | 59 |
59 | 60 | MIN_INFERRED_NOISE_LEVEL = 1e-6 |
@@ -294,7 +295,7 @@ def load_mcmc_samples( |
294 | 295 | return mean_module, covar_module, likelihood |
295 | 296 |
296 | 297 |
297 | | -class SaasFullyBayesianSingleTaskGP(SingleTaskGP): |
| 298 | +class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel): |
298 | 299 | r"""A fully Bayesian single-task GP model with the SAAS prior. |
299 | 300 |
300 | 301 | This model assumes that the inputs have been normalized to [0, 1]^d and that |
@@ -364,7 +365,7 @@ def __init__( |
364 | 365 | train_Yvar = train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL) |
365 | 366 |
366 | 367 | X_tf, Y_tf, _ = self._transform_tensor_args(X=train_X, Y=train_Y) |
367 | | - super(SingleTaskGP, self).__init__( |
| 368 | + super().__init__( |
368 | 369 | train_inputs=X_tf, train_targets=Y_tf, likelihood=GaussianLikelihood() |
369 | 370 | ) |
370 | 371 | self.mean_module = None |
@@ -473,9 +474,19 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): |
473 | 474 | super().load_state_dict(state_dict=state_dict, strict=strict) |
474 | 475 |
475 | 476 | def forward(self, X: Tensor) -> MultivariateNormal: |
| 477 | + """ |
| 478 | + Unlike in other classes' `forward` methods, there is no `if self.training` |
| 479 | + block, because it ought to be unreachable: If `self.train()` has been called, |
| 480 | + then `self.covar_module` will be None, `_check_if_fitted()` will fail, and the
| 481 | + rest of this method will not run. |
| 482 | + """ |
476 | 483 | self._check_if_fitted() |
477 | | - return super().forward(X.unsqueeze(MCMC_DIM)) |
| 484 | + x = X.unsqueeze(MCMC_DIM) |
| 485 | + mean_x = self.mean_module(x) |
| 486 | + covar_x = self.covar_module(x) |
| 487 | + return MultivariateNormal(mean_x, covar_x) |
478 | 488 |
| 489 | + # pyre-ignore[14]: Inconsistent override |
479 | 490 | def posterior( |
480 | 491 | self, |
481 | 492 | X: Tensor, |
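
For context, here is a minimal usage sketch of the class this diff modifies. It is not part of the change itself; it assumes the standard `fit_fully_bayesian_model_nuts` helper from `botorch.fit` and uses toy data purely for illustration. It shows that the mean/covariance modules are only populated once NUTS has run, and that predictions carry a leading MCMC-sample batch dimension because the new `forward` unsqueezes the input at `MCMC_DIM`:

```python
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Toy data; the class docstring assumes inputs normalized to [0, 1]^d.
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)

# Until NUTS populates mean_module / covar_module / likelihood, they are None
# and forward() / posterior() raise via _check_if_fitted().
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

with torch.no_grad():
    posterior = model.posterior(torch.rand(5, 3, dtype=torch.double))
    # The posterior mean carries a leading MCMC-sample batch dimension,
    # e.g. (num_retained_samples, 5, 1), from the unsqueeze at MCMC_DIM.
    print(posterior.mean.shape)
```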