
Commit 0fb00ef

saitcakmak authored and facebook-github-bot committed
Support SAAS ensemble models in RFFs (#1530)
Summary:
Pull Request resolved: #1530

RFFs currently support batched models, but RFFs constructed from batched models do not support batched evaluation unless the model batch dimension is explicitly unsqueezed in the input. This prevents their use in acquisition functions. The changes here make it possible to construct an RFF from a SAAS ensemble model and use it in acquisition functions.

Reviewed By: Balandat

Differential Revision: D41569212

fbshipit-source-id: 6d371a0c4b1ab3be24f7e790b3527cc5f29e6808
1 parent 6e8df86 commit 0fb00ef
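
As a hedged usage sketch of what this commit enables (not part of the diff; the data, fitting settings, and candidate pool below are illustrative assumptions), one can draw a single RFF-based sample path from a fitted SAAS model via get_gp_samples and maximize it over a candidate pool, Thompson-sampling style:

import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.utils.gp_sampling import get_gp_samples

# Fit a SAAS ensemble model on toy data (4-dimensional inputs, 1 output).
train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.randn(10, 1, dtype=torch.double)
saas_model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    saas_model, warmup_steps=32, num_samples=16, thinning=4  # small, illustrative settings
)

# Draw one RFF-based sample path; the result is a deterministic model whose
# output carries the MCMC ensemble as a leading batch dimension.
gp_sample = get_gp_samples(model=saas_model, num_outputs=1, n_samples=1)

# Thompson-sampling style step: evaluate the sample path on a random candidate
# pool and pick the candidate with the best ensemble-averaged value.
candidates = torch.rand(512, 4, dtype=torch.double)
values = gp_sample(candidates)  # num_mcmc_samples x 512 x 1
best_x = candidates[values.mean(dim=0).squeeze(-1).argmax()]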

3 files changed: +55 −3 lines changed

botorch/utils/gp_sampling.py

Lines changed: 17 additions & 1 deletion
@@ -17,6 +17,7 @@
 from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.models.multitask import MultiTaskGP
 from botorch.utils.sampling import manual_seed
+from botorch.utils.transforms import is_fully_bayesian
 from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
 from linear_operator.utils.cholesky import psd_safe_cholesky
 from torch import Tensor
@@ -228,7 +229,20 @@ def forward(self, X: Tensor) -> Tensor:
             a `batch_shape`, the output `batch_shape` will be
             `(sample_shape) x (kernel_batch_shape)`.
         """
-        self._check_forward_X_shape_compatibility(X)
+        try:
+            self._check_forward_X_shape_compatibility(X)
+        except ValueError as e:
+            # A workaround to support batched SAAS models.
+            # TODO: Support batch evaluation of multi-sample RFFs as well.
+            # Multi-sample RFFs have input batch as the 0-th dimension,
+            # which is different than other posteriors which would have
+            # the sample shape as the 0-th dimension.
+            if len(self.kernel_batch_shape) == 1:
+                X = X.unsqueeze(-3)
+                self._check_forward_X_shape_compatibility(X)
+            else:
+                raise e
+
         # X is of shape (additional_batch_shape) x (sample_shape)
         # x (kernel_batch_shape) x n x d.
         # Weights is of shape (sample_shape) x (kernel_batch_shape) x d x num_rff.
@@ -489,6 +503,7 @@ def get_gp_samples(
                 models[m].outcome_transform = _octf
             if _intf is not None:
                 base_gp_samples.models[m].input_transform = _intf
+        base_gp_samples.is_fully_bayesian = is_fully_bayesian(model=model)
         return base_gp_samples
     elif n_samples > 1:
         base_gp_samples = get_deterministic_model_multi_samples(
@@ -507,4 +522,5 @@ def get_gp_samples(
     if octf is not None:
         base_gp_samples.outcome_transform = octf
         model.outcome_transform = octf
+    base_gp_samples.is_fully_bayesian = is_fully_bayesian(model=model)
     return base_gp_samples
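
The X.unsqueeze(-3) workaround above inserts the missing kernel-batch (MCMC) dimension so that batch-evaluated inputs broadcast against the per-MCMC-sample RFF weights. A minimal plain-torch sketch of that broadcasting (illustrative shapes only, not the BoTorch internals):

import torch

# Illustrative sizes: 4 MCMC samples, a 5-point evaluation batch, n=2, d=3, 8 RFFs.
num_mcmc, batch, n, d, num_rff = 4, 5, 2, 3, 8

weights = torch.randn(num_mcmc, d, num_rff)  # (kernel_batch_shape) x d x num_rff
X = torch.rand(batch, n, d)                  # batched input without the model batch dim

# Without the extra dim, the evaluation batch (5) would clash with the MCMC dim (4).
phi = X.unsqueeze(-3) @ weights              # batch x 1 x n x d  @  num_mcmc x d x num_rff
print(phi.shape)                             # torch.Size([5, 4, 2, 8])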

botorch/utils/transforms.py

Lines changed: 8 additions & 2 deletions
@@ -192,11 +192,17 @@ def is_fully_bayesian(model: Model) -> bool:
         SaasFullyBayesianMultiTaskGP,
     ]
 
-    if any(isinstance(model, m_cls) for m_cls in full_bayesian_model_cls):
+    if any(
+        isinstance(model, m_cls) or getattr(model, "is_fully_bayesian", False)
+        for m_cls in full_bayesian_model_cls
+    ):
         return True
     elif isinstance(model, ModelList):
         for m in model.models:
-            if any(isinstance(m, m_cls) for m_cls in full_bayesian_model_cls):
+            if any(
+                isinstance(m, m_cls) or getattr(model, "is_fully_bayesian", False)
+                for m_cls in full_bayesian_model_cls
+            ):
                 return True
             elif isinstance(m, ModelListGP) and any(
                 isinstance(m_sub, m_cls)
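
The getattr(model, "is_fully_bayesian", False) fallback makes the check duck-typed: any model carrying a truthy is_fully_bayesian attribute, such as the deterministic RFF sample tagged by get_gp_samples above, is now treated as fully Bayesian. A small sketch of that behavior (assuming the patched helper, with a GenericDeterministicModel standing in for the RFF sample):

from botorch.models.deterministic import GenericDeterministicModel
from botorch.utils.transforms import is_fully_bayesian

det_model = GenericDeterministicModel(f=lambda X: X.sum(dim=-1, keepdim=True))
print(is_fully_bayesian(det_model))  # False: not a SAAS class and no flag set

det_model.is_fully_bayesian = True   # the flag that get_gp_samples now sets
print(is_fully_bayesian(det_model))  # True, via the getattr fallback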

test/utils/test_gp_sampling.py

Lines changed: 30 additions & 0 deletions
@@ -11,6 +11,7 @@
 import torch
 from botorch.models.converter import batched_to_model_list
 from botorch.models.deterministic import DeterministicModel
+from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
 from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
 from botorch.models.model import ModelList
 from botorch.models.multitask import MultiTaskGP
@@ -26,6 +27,7 @@
     RandomFourierFeatures,
 )
 from botorch.utils.testing import BotorchTestCase
+from botorch.utils.transforms import is_fully_bayesian
 from gpytorch.kernels import MaternKernel, PeriodicKernel, RBFKernel, ScaleKernel
 from torch.distributions import MultivariateNormal
 
@@ -661,3 +663,31 @@ def test_with_fixed_noise(self):
             torch.Size([2, 1]) if n_samples == 1 else torch.Size([n_samples, 2, 1])
         )
         self.assertEqual(samples.shape, expected_shape)
+
+    def test_with_saas_models(self):
+        # Construct a SAAS model.
+        tkwargs = {"dtype": torch.double, "device": self.device}
+        num_samples = 4
+        model = SaasFullyBayesianSingleTaskGP(
+            train_X=torch.rand(10, 4, **tkwargs), train_Y=torch.randn(10, 1, **tkwargs)
+        )
+        mcmc_samples = {
+            "lengthscale": torch.rand(num_samples, 1, 4, **tkwargs),
+            "outputscale": torch.rand(num_samples, **tkwargs),
+            "mean": torch.randn(num_samples, **tkwargs),
+            "noise": torch.rand(num_samples, 1, **tkwargs),
+        }
+        model.load_mcmc_samples(mcmc_samples)
+        # Test proper setup & sampling support.
+        gp_samples = get_gp_samples(
+            model=model,
+            num_outputs=1,
+            n_samples=1,
+        )
+        self.assertTrue(is_fully_bayesian(gp_samples))
+        # Non-batch evaluation.
+        samples = gp_samples(torch.rand(2, 4, **tkwargs))
+        self.assertEqual(samples.shape, torch.Size([4, 2, 1]))
+        # Batch evaluation.
+        samples = gp_samples(torch.rand(5, 2, 4, **tkwargs))
+        self.assertEqual(samples.shape, torch.Size([5, 4, 2, 1]))
