Skip to content

Commit 90f8ca4

Browse files
qingfeng10facebook-github-bot
authored andcommitted
Implement multitask SAAS (#1181)
Summary: Pull Request resolved: #1181 Move multitask SAAS GP into BoTorch Reviewed By: dme65 Differential Revision: D35573439 fbshipit-source-id: 92bef5e423e48092f8848562cf6495729c403c8f
1 parent 57decad commit 90f8ca4

File tree

9 files changed

+928
-10
lines changed

9 files changed

+928
-10
lines changed

botorch/fit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
import logging
1414
import warnings
1515
from copy import deepcopy
16-
from typing import Any, Callable
16+
from typing import Any, Callable, Union
1717

1818
from botorch.exceptions.errors import UnsupportedError
1919
from botorch.exceptions.warnings import BotorchWarning, OptimizationWarning
2020
from botorch.models.converter import batched_to_model_list, model_list_to_batched
2121
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
22+
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
23+
2224
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
2325
from botorch.optim.fit import fit_gpytorch_scipy
2426
from botorch.optim.utils import sample_all_priors
@@ -157,7 +159,7 @@ def fit_gpytorch_model(
157159

158160

159161
def fit_fully_bayesian_model_nuts(
160-
model: SaasFullyBayesianSingleTaskGP,
162+
model: Union[SaasFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP],
161163
max_tree_depth: int = 6,
162164
warmup_steps: int = 512,
163165
num_samples: int = 256,

botorch/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
PosteriorMeanModel,
1616
)
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
18+
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
19+
1820
from botorch.models.gp_regression import (
1921
FixedNoiseGP,
2022
HeteroskedasticSingleTaskGP,
@@ -39,6 +41,7 @@
3941
"FixedNoiseGP",
4042
"FixedNoiseMultiTaskGP",
4143
"SaasFullyBayesianSingleTaskGP",
44+
"SaasFullyBayesianMultiTaskGP",
4245
"GenericDeterministicModel",
4346
"HeteroskedasticSingleTaskGP",
4447
"HigherOrderGP",

botorch/models/fully_bayesian.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ class SaasPyroModel(PyroModel):
145145
`covar_module`).
146146
"""
147147

148+
def set_inputs(
149+
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
150+
):
151+
super().set_inputs(train_X, train_Y, train_Yvar)
152+
self.ard_num_dims = self.train_X.shape[-1]
153+
148154
def sample(self) -> None:
149155
r"""Sample from the SAAS model.
150156
@@ -155,7 +161,7 @@ def sample(self) -> None:
155161
outputscale = self.sample_outputscale(concentration=2.0, rate=0.15, **tkwargs)
156162
mean = self.sample_mean(**tkwargs)
157163
noise = self.sample_noise(**tkwargs)
158-
lengthscale = self.sample_lengthscale(dim=self.train_X.shape[-1], **tkwargs)
164+
lengthscale = self.sample_lengthscale(dim=self.ard_num_dims, **tkwargs)
159165
k = matern52_kernel(X=self.train_X, lengthscale=lengthscale)
160166
k = outputscale * k + noise * torch.eye(self.train_X.shape[0], **tkwargs)
161167
pyro.sample(
@@ -252,7 +258,7 @@ def load_mcmc_samples(
252258
mean_module = ConstantMean(batch_shape=batch_shape).to(**tkwargs)
253259
covar_module = ScaleKernel(
254260
base_kernel=MaternKernel(
255-
ard_num_dims=self.train_X.shape[-1],
261+
ard_num_dims=self.ard_num_dims,
256262
batch_shape=batch_shape,
257263
),
258264
batch_shape=batch_shape,

0 commit comments

Comments
 (0)