
Commit e3b5a52

qingfeng10 authored and facebook-github-bot committed
add construct inputs for SaasFullyBayesianMultiTaskGP (#1203)
Summary:
Pull Request resolved: #1203

As title. D35573439 is too big, so I made a separate diff to add `construct_inputs` for `SaasFullyBayesianMultiTaskGP`. This override is necessary in order to use the model in Ax.

Reviewed By: dme65

Differential Revision: D36136031

fbshipit-source-id: c6aee985f7f5468252909306cf4f577818debcb0
1 parent 90f8ca4 commit e3b5a52

File tree: 2 files changed, +56 -0 lines changed


botorch/models/fully_bayesian_multitask.py

Lines changed: 24 additions & 0 deletions
@@ -25,6 +25,7 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.posteriors.fully_bayesian import FullyBayesianPosterior, MCMC_DIM
 from botorch.sampling.samplers import MCSampler
+from botorch.utils.datasets import SupervisedDataset
 from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels import MaternKernel
 from gpytorch.kernels.kernel import Kernel
@@ -372,3 +373,26 @@ def forward(self, X: Tensor) -> MultivariateNormal:
         covar_i = self.task_covar_module(latent_features)
         covar = covar_x.mul(covar_i)
         return MultivariateNormal(mean_x, covar)
+
+    @classmethod
+    def construct_inputs(
+        cls,
+        training_data: Dict[str, SupervisedDataset],
+        task_feature: int,
+        rank: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        r"""Construct `Model` keyword arguments from dictionary of `SupervisedDataset`.
+
+        Args:
+            training_data: Dictionary of `SupervisedDataset`.
+            task_feature: Column index of embedded task indicator features. For details,
+                see `parse_training_data`.
+            rank: The rank of the cross-task covariance matrix.
+        """
+
+        inputs = super().construct_inputs(
+            training_data=training_data, task_feature=task_feature, rank=rank, **kwargs
+        )
+        inputs.pop("task_covar_prior")
+        return inputs
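
The new classmethod turns a dictionary of `SupervisedDataset`s into constructor arguments for the model. Below is a minimal, hypothetical usage sketch, not part of this commit: `build_model` and its `datasets` argument are illustrative stand-ins for whatever assembles the `Dict[str, SupervisedDataset]` (e.g., Ax, or the test helper used in the test file below).

from typing import Dict

from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.utils.datasets import SupervisedDataset


def build_model(
    datasets: Dict[str, SupervisedDataset], task_feature: int = 0
) -> SaasFullyBayesianMultiTaskGP:
    # construct_inputs delegates parsing to the parent class and then drops
    # `task_covar_prior`, which the fully Bayesian constructor does not accept.
    model_kwargs = SaasFullyBayesianMultiTaskGP.construct_inputs(
        training_data=datasets,
        task_feature=task_feature,
        rank=1,
    )
    # The remaining keys (e.g. train_X, train_Y, train_Yvar, task_feature, rank)
    # map directly onto the model constructor.
    return SaasFullyBayesianMultiTaskGP(**model_kwargs)

Popping `task_covar_prior` is the substantive difference from the parent implementation: the SAAS model builds its task covariance from MCMC-sampled latent features (see `forward` above), so a task covariance prior is not a valid constructor argument.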

test/models/test_fully_bayesian_multitask.py

Lines changed: 32 additions & 0 deletions
@@ -46,6 +46,8 @@
 from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
 from gpytorch.means import ConstantMean
 
+from .test_multitask import _gen_fixed_noise_model_and_data
+
 
 class TestFullyBayesianMultiTaskGP(BotorchTestCase):
     def _get_data_and_model(self, task_rank: Optional[int] = 1, **tkwargs):
@@ -509,3 +511,33 @@ def test_load_samples(self):
                     mcmc_samples["latent_features"],
                 )
             )
+
+    def test_construct_inputs(self):
+        for dtype in [torch.float, torch.double]:
+            tkwargs = {"device": self.device, "dtype": dtype}
+            task_feature = 0
+
+            (
+                _,
+                datasets,
+                (train_X, train_Y, train_Yvar),
+            ) = _gen_fixed_noise_model_and_data(task_feature=task_feature, **tkwargs)
+
+            model = SaasFullyBayesianMultiTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+                task_feature=task_feature,
+            )
+
+            data_dict = model.construct_inputs(
+                datasets,
+                task_feature=task_feature,
+                rank=1,
+            )
+            self.assertTrue(torch.equal(data_dict["train_X"], train_X))
+            self.assertTrue(torch.equal(data_dict["train_Y"], train_Y))
+            self.assertTrue(torch.allclose(data_dict["train_Yvar"], train_Yvar))
+            self.assertEqual(data_dict["task_feature"], task_feature)
+            self.assertEqual(data_dict["rank"], 1)
+            self.assertTrue("task_covar_prior" not in data_dict)
