Skip to content

Commit 6e01234

Browse files
dme65facebook-github-bot
authored andcommitted
Support SaasFullyBayesianSingleTaskGP in prune_inferior_points (meta-pytorch#1260)
Summary: Pull Request resolved: meta-pytorch#1260 See title Reviewed By: saitcakmak Differential Revision: D37091706 fbshipit-source-id: c091683cdb52a5d2f21f406d98589f94d9fa218b
1 parent 0188831 commit 6e01234

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

botorch/acquisition/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626
)
2727
from botorch.exceptions.errors import UnsupportedError
2828
from botorch.exceptions.warnings import SamplingWarning
29+
from botorch.models.fully_bayesian import MCMC_DIM
2930
from botorch.models.model import Model
3031
from botorch.sampling.samplers import IIDNormalSampler, MCSampler, SobolQMCNormalSampler
3132
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
3233
FastNondominatedPartitioning,
3334
NondominatedPartitioning,
3435
)
36+
from botorch.utils.transforms import is_fully_bayesian
3537
from torch import Tensor
3638
from torch.quasirandom import SobolEngine
3739

@@ -320,6 +322,10 @@ def prune_inferior_points(
320322
with `N_nz` the number of points in `X` that have non-zero (empirical,
321323
under `num_samples` samples) probability of being the best point.
322324
"""
325+
if marginalize_dim is None and is_fully_bayesian(model):
326+
# TODO: Properly deal with marginalizing fully Bayesian models
327+
marginalize_dim = MCMC_DIM
328+
323329
if X.ndim > 2:
324330
# TODO: support batched inputs (req. dealing with ragged tensors)
325331
raise UnsupportedError(

botorch/utils/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ def is_fully_bayesian(model: Model) -> bool:
183183
Returns:
184184
True if at least one model is a `SaasFullyBayesianSingleTaskGP`
185185
"""
186-
from botorch.models import ModelList
186+
from botorch.models import ModelList, ModelListGP
187187
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
188188

189189
if isinstance(model, SaasFullyBayesianSingleTaskGP):
190190
return True
191-
elif isinstance(model, ModelList) and any(
191+
elif isinstance(model, (ModelList, ModelListGP)) and any(
192192
isinstance(m, SaasFullyBayesianSingleTaskGP) for m in model.models
193193
):
194194
return True

test/models/test_fully_bayesian.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
qExpectedHypervolumeImprovement,
2828
qNoisyExpectedHypervolumeImprovement,
2929
)
30+
from botorch.acquisition.utils import prune_inferior_points
3031
from botorch.models import ModelList, ModelListGP
3132
from botorch.models.deterministic import GenericDeterministicModel
3233
from botorch.models.fully_bayesian import (
@@ -423,6 +424,10 @@ def test_acquisition_functions(self):
423424
test_X = torch.rand(*batch_shape, 1, 4, **tkwargs)
424425
self.assertEqual(acqf(test_X).shape, torch.Size(batch_shape))
425426

427+
# Test prune_inferior_points
428+
X_pruned = prune_inferior_points(model=model, X=train_X)
429+
self.assertTrue(X_pruned.ndim == 2 and X_pruned.shape[-1] == 4)
430+
426431
# Test prune_inferior_points_multi_objective
427432
for model_list in [ModelListGP(model, model), ModelList(deterministic, model)]:
428433
X_pruned = prune_inferior_points_multi_objective(

test/utils/test_transforms.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,19 @@
88
from typing import Any
99

1010
import torch
11+
from botorch.models import (
12+
GenericDeterministicModel,
13+
ModelList,
14+
ModelListGP,
15+
SaasFullyBayesianSingleTaskGP,
16+
SingleTaskGP,
17+
)
1118
from botorch.models.model import Model
1219
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
1320
from botorch.utils.transforms import (
1421
_verify_output_shape,
1522
concatenate_pending_points,
23+
is_fully_bayesian,
1624
match_batch_shape,
1725
normalize,
1826
normalize_indices,
@@ -278,3 +286,23 @@ def test_squeeze_last_dim(self):
278286
Y_squeezed = squeeze_last_dim(Y=Y)
279287
self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
280288
self.assertTrue(torch.equal(Y_squeezed, Y.squeeze(-1)))
289+
290+
291+
class TestIsFullyBayesian(BotorchTestCase):
292+
def test_is_fully_bayesian(self):
293+
X, Y = torch.rand(3, 2), torch.randn(3, 1)
294+
saas = SaasFullyBayesianSingleTaskGP(train_X=X, train_Y=Y)
295+
vanilla_gp = SingleTaskGP(train_X=X, train_Y=Y)
296+
deterministic = GenericDeterministicModel(f=lambda x: x)
297+
# Single model
298+
self.assertTrue(is_fully_bayesian(model=saas))
299+
self.assertFalse(is_fully_bayesian(model=vanilla_gp))
300+
self.assertFalse(is_fully_bayesian(model=deterministic))
301+
# ModelListGP
302+
self.assertTrue(is_fully_bayesian(model=ModelListGP(saas, saas)))
303+
self.assertTrue(is_fully_bayesian(model=ModelListGP(saas, vanilla_gp)))
304+
self.assertFalse(is_fully_bayesian(model=ModelListGP(vanilla_gp, vanilla_gp)))
305+
# ModelList
306+
self.assertTrue(is_fully_bayesian(model=ModelList(saas, saas)))
307+
self.assertTrue(is_fully_bayesian(model=ModelList(saas, deterministic)))
308+
self.assertFalse(is_fully_bayesian(model=ModelList(vanilla_gp, deterministic)))

0 commit comments

Comments
 (0)