Skip to content

Commit 2cc41dc

Browse files
David Erikssonmeta-codesync[bot]
authored andcommitted
Add construct_inputs to EnsembleMapSaasSingleTaskGP (#3040)
Summary: Pull Request resolved: #3040 We need this to pass in `num_taus` from MBM. Reviewed By: Balandat Differential Revision: D83442381 fbshipit-source-id: 6a83e2cd79d38efcb3605a34608cc00f1a222902
1 parent 1f55694 commit 2cc41dc

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

botorch/models/map_saas.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
)
1717
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
1818
from botorch.utils.constraints import LogTransformedInterval
19+
from botorch.utils.containers import BotorchContainer
20+
from botorch.utils.datasets import SupervisedDataset
1921
from botorch.utils.types import _DefaultType, DEFAULT
2022
from gpytorch.constraints import Interval
2123
from gpytorch.kernels import AdditiveKernel, Kernel, MaternKernel, ScaleKernel
@@ -561,3 +563,19 @@ def posterior(
561563
**kwargs,
562564
)
563565
return GaussianMixturePosterior(distribution=posterior.distribution)
566+
567+
@classmethod
568+
def construct_inputs(
569+
cls,
570+
training_data: SupervisedDataset,
571+
*,
572+
num_taus: int = 4,
573+
) -> dict[str, BotorchContainer | Tensor]:
574+
r"""Construct `Model` keyword arguments from a dict of `SupervisedDataset`.
575+
576+
Args:
577+
training_data: A `SupervisedDataset` containing the training data.
578+
num_taus: Number of taus to use in the ensemble (4 if omitted).
579+
"""
580+
base_inputs = super().construct_inputs(training_data=training_data)
581+
return {**base_inputs, "num_taus": num_taus}

test/models/test_map_saas.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from botorch.posteriors.gpytorch import GPyTorchPosterior
3535
from botorch.test_utils.mock import mock_optimize
3636
from botorch.utils.constraints import LogTransformedInterval
37+
from botorch.utils.datasets import SupervisedDataset
3738
from botorch.utils.testing import BotorchTestCase
3839
from gpytorch.constraints import Interval
3940
from gpytorch.kernels import AdditiveKernel, MaternKernel, ScaleKernel
@@ -568,6 +569,56 @@ def test_ensemble_map_saas_validation(self) -> None:
568569
train_X=torch.rand(2, 5, 3), train_Y=torch.rand(2, 5, 1)
569570
)
570571

572+
def test_ensemble_map_saas_construct_inputs(self) -> None:
573+
"""Test the construct_inputs class method for EnsembleMapSaasSingleTaskGP."""
574+
575+
train_X, train_Y, _ = self._get_data()
576+
training_data = SupervisedDataset(
577+
X=train_X, Y=train_Y, feature_names=["x1", "x2", "x3"], outcome_names=["y"]
578+
)
579+
580+
# Test with default num_taus
581+
inputs_default = EnsembleMapSaasSingleTaskGP.construct_inputs(
582+
training_data=training_data
583+
)
584+
self.assertIn("num_taus", inputs_default)
585+
self.assertEqual(inputs_default["num_taus"], 4)
586+
self.assertIn("train_X", inputs_default)
587+
self.assertIn("train_Y", inputs_default)
588+
self.assertAllClose(inputs_default["train_X"], train_X)
589+
self.assertAllClose(inputs_default["train_Y"], train_Y)
590+
591+
# Test with custom num_taus
592+
custom_num_taus = 6
593+
inputs_custom = EnsembleMapSaasSingleTaskGP.construct_inputs(
594+
training_data=training_data, num_taus=custom_num_taus
595+
)
596+
self.assertIn("num_taus", inputs_custom)
597+
self.assertEqual(inputs_custom["num_taus"], custom_num_taus)
598+
self.assertIn("train_X", inputs_custom)
599+
self.assertIn("train_Y", inputs_custom)
600+
self.assertAllClose(inputs_custom["train_X"], train_X)
601+
self.assertAllClose(inputs_custom["train_Y"], train_Y)
602+
603+
# Test with train_Yvar in the dataset
604+
train_Yvar = 0.1 * torch.rand_like(train_Y)
605+
training_data_with_yvar = SupervisedDataset(
606+
X=train_X,
607+
Y=train_Y,
608+
Yvar=train_Yvar,
609+
feature_names=["x1", "x2", "x3"],
610+
outcome_names=["y"],
611+
)
612+
inputs_with_yvar = EnsembleMapSaasSingleTaskGP.construct_inputs(
613+
training_data=training_data_with_yvar, num_taus=3
614+
)
615+
self.assertIn("train_Yvar", inputs_with_yvar)
616+
self.assertAllClose(inputs_with_yvar["train_Yvar"], train_Yvar)
617+
self.assertEqual(inputs_with_yvar["num_taus"], 3)
618+
model_with_yvar = EnsembleMapSaasSingleTaskGP(**inputs_with_yvar)
619+
self.assertIsInstance(model_with_yvar, EnsembleMapSaasSingleTaskGP)
620+
self.assertEqual(model_with_yvar.batch_shape, torch.Size([3]))
621+
571622

572623
class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase):
573624
def _get_data_and_model(

0 commit comments

Comments
 (0)