Skip to content

Commit 98eb060

Browse files
dme65facebook-github-bot
authored andcommitted
Add construct_inputs to SAASBO (#1136)
Summary: Pull Request resolved: #1136 Needed for BotAx support. Reviewed By: saitcakmak Differential Revision: D35164975 fbshipit-source-id: 2f496508c3f37a00c49fcbb971cf868740c15a22
1 parent dc055be commit 98eb060

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

botorch/models/fully_bayesian.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from botorch.models.utils import validate_input_scaling
3333
from botorch.posteriors.fully_bayesian import FullyBayesianPosterior
3434
from botorch.sampling.samplers import MCSampler
35+
from botorch.utils.containers import TrainingData
3536
from gpytorch.constraints import GreaterThan
3637
from gpytorch.distributions.multivariate_normal import MultivariateNormal
3738
from gpytorch.kernels import MaternKernel, ScaleKernel
@@ -419,3 +420,19 @@ def posterior(
419420
marginalize_over_mcmc_samples=marginalize_over_mcmc_samples,
420421
)
421422
return posterior
423+
424+
@classmethod
425+
def construct_inputs(
426+
cls, training_data: TrainingData, **kwargs: Any
427+
) -> Dict[str, Any]:
428+
r"""Construct kwargs for the `Model` from `TrainingData` and other options.
429+
430+
Args:
431+
training_data: `TrainingData` container with data for single outcome
432+
or for multiple outcomes for batched multi-output case.
433+
**kwargs: None expected for this class.
434+
"""
435+
inputs = {"train_X": training_data.X, "train_Y": training_data.Y}
436+
if training_data.Yvar is not None:
437+
inputs["train_Yvar"] = training_data.Yvar
438+
return inputs

test/models/test_fully_bayesian.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from botorch.models.transforms import Normalize, Standardize
3636
from botorch.posteriors import FullyBayesianPosterior
3737
from botorch.sampling.samplers import IIDNormalSampler
38+
from botorch.utils.containers import TrainingData
3839
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
3940
NondominatedPartitioning,
4041
)
@@ -360,3 +361,24 @@ def test_load_samples(self):
360361
train_Yvar.clamp(MIN_INFERRED_NOISE_LEVEL),
361362
)
362363
)
364+
365+
def test_construct_inputs(self):
366+
for infer_noise, dtype in itertools.product(
367+
(True, False), (torch.float, torch.double)
368+
):
369+
tkwargs = {"device": self.device, "dtype": dtype}
370+
train_X, train_Y, train_Yvar, model = self._get_data_and_model(
371+
infer_noise=infer_noise, **tkwargs
372+
)
373+
training_data = TrainingData.from_block_design(
374+
X=train_X,
375+
Y=train_Y,
376+
Yvar=train_Yvar,
377+
)
378+
data_dict = model.construct_inputs(training_data)
379+
if infer_noise:
380+
self.assertTrue("train_Yvar" not in data_dict)
381+
else:
382+
self.assertTrue(torch.equal(data_dict["train_Yvar"], train_Yvar))
383+
self.assertTrue(torch.equal(data_dict["train_X"], train_X))
384+
self.assertTrue(torch.equal(data_dict["train_Y"], train_Y))

0 commit comments

Comments
 (0)