Skip to content

Commit 4b2363d

Browse files
saitcakmak and facebook-github-bot
authored and committed
Add EnsembleMapSaasGP (#3035)
Summary: Introduces a simple `EnsembleMapSaasGP` model that will replace `get_fitted_map_saas_ensemble` (which fits individual non-ensemble models and combines them into a fully Bayesian GP). Internally, the model is a batched `ExactGP` that behaves just like a multi-output `SingleTaskGP`. The `posterior` method is overridden to produce a `GaussianMixturePosterior`, which retains the old behavior of the ensemble model. The benefit of this model class is that it can be fit just like any other GP model, using `ExactMarginalLogLikelihood` and `fit_gpytorch_mll`. As such, it is fully compatible with Ax's MBM setup (as long as `allow_batched_models=False`). Reviewed By: sdaulton Differential Revision: D83701925
1 parent ce9cc6b commit 4b2363d

File tree

3 files changed

+207
-10
lines changed

3 files changed

+207
-10
lines changed

botorch/models/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
)
1717
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
1818
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
19-
2019
from botorch.models.gp_regression import SingleTaskGP
2120
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
2221
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
2322
from botorch.models.higher_order_gp import HigherOrderGP
24-
25-
from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP
23+
from botorch.models.map_saas import (
24+
add_saas_prior,
25+
AdditiveMapSaasSingleTaskGP,
26+
EnsembleMapSaasGP,
27+
)
2628
from botorch.models.model import ModelList
2729
from botorch.models.model_list_gp_regression import ModelListGP
2830
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
@@ -34,6 +36,7 @@
3436
"AffineDeterministicModel",
3537
"AffineFidelityCostModel",
3638
"ApproximateGPyTorchModel",
39+
"EnsembleMapSaasGP",
3740
"SaasFullyBayesianSingleTaskGP",
3841
"SaasFullyBayesianMultiTaskGP",
3942
"GenericDeterministicModel",

botorch/models/map_saas.py

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import Any
67

78
import torch
9+
from botorch.acquisition.objective import PosteriorTransform
810
from botorch.exceptions import UnsupportedError
911
from botorch.models.gp_regression import SingleTaskGP
1012
from botorch.models.transforms.input import InputTransform
1113
from botorch.models.transforms.outcome import OutcomeTransform
1214
from botorch.models.utils.gpytorch_modules import (
1315
get_gaussian_likelihood_with_lognormal_prior,
1416
)
17+
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
1518
from botorch.utils.constraints import LogTransformedInterval
1619
from botorch.utils.types import _DefaultType, DEFAULT
1720
from gpytorch.constraints import Interval
@@ -30,12 +33,15 @@
3033
class SaasPriorHelper:
3134
"""Helper class for specifying parameter and setting closures."""
3235

33-
def __init__(self, tau: float | None = None):
36+
def __init__(self, tau: Tensor | float | None = None):
3437
"""Instantiates a new helper object.
3538
3639
Args:
3740
tau: Value of the global shrinkage parameter. If `None`, the tau will be
3841
a free parameter and inferred from the data.
42+
Tau can be a tensor for batched models, like `EnsembleMapSaasGP`,
43+
where each batch has a different sparsity prior. If tau is a tensor,
44+
it must have shape `batch_shape`.
3945
"""
4046
self._tau = torch.as_tensor(tau) if tau is not None else None
4147

@@ -102,10 +108,8 @@ def tau_prior_setting_closure(self, m: Kernel, value: Tensor) -> None:
102108
"""
103109
lb = m.raw_tau_constraint.lower_bound.to(m.raw_tau)
104110
ub = m.raw_tau_constraint.upper_bound.to(m.raw_tau)
105-
m.raw_tau.data.fill_(
106-
m.raw_tau_constraint.inverse_transform(
107-
value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
108-
).item()
111+
m.raw_tau.data = m.raw_tau_constraint.inverse_transform(
112+
value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
109113
)
110114

111115

@@ -218,7 +222,7 @@ def get_map_saas_model(
218222
)
219223
# NOTE: need to call `to` to set device and dtype before calling `add_saas_prior`,
220224
# since the SAAS prior contains tensors that are not parameters of the model, and
221-
# terefore not automatically moved to the correct device with a `to` call on the
225+
# therefore not automatically moved to the correct device with a `to` call on the
222226
# model.
223227
base_kernel.to(train_X)
224228
add_saas_prior(base_kernel=base_kernel, tau=tau)
@@ -421,3 +425,139 @@ def __init__(
421425
)
422426
# Make sure that all buffers and parameters have the correct device and dtype
423427
self.to(dtype=train_X.dtype, device=train_X.device)
428+
429+
430+
class EnsembleMapSaasGP(SingleTaskGP):
    """A batched ensemble of MAP-estimated SAAS GPs.

    The training data is replicated `num_taus` times along a leading batch
    dimension, and each batch element gets its own value of the global
    shrinkage parameter `tau` in the SAAS prior. The `posterior` method wraps
    the batched predictive distribution in a `GaussianMixturePosterior`.
    """

    _is_ensemble = True

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Yvar: Tensor | None = None,
        num_taus: int = 4,
        taus: Tensor | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
        input_transform: InputTransform | None = None,
    ) -> None:
        """Instantiates an ``EnsembleMapSaasGP``, which is a batched ensemble of
        ``SingleTaskGP``s with the Matern-5/2 kernel and a SAAS prior. The model is
        intended to be trained with ``ExactMarginalLogLikelihood`` and
        ``fit_gpytorch_mll``. Under the hood, the model is equivalent to a
        multi-output ``BatchedMultiOutputGPyTorchModel``, but it produces a
        ``GaussianMixturePosterior``, which leads to ensembling of the model outputs.

        Args:
            train_X: An `n x d` tensor of training features.
            train_Y: An `n x 1` tensor of training observations.
            train_Yvar: An optional `n x 1` tensor of observed measurement noise.
            num_taus: The number of taus to use (4 if omitted). Each tau is
                a sparsity parameter for the corresponding kernel in the ensemble.
            taus: An optional tensor of shape `num_taus` containing the taus to use.
                If omitted, the taus are sampled from a HalfCauchy(0.1) distribution.
                The taus are converted to the dtype and device of `train_X`.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale). We use a
                `Standardize` transform if no `outcome_transform` is specified.
                Pass down `None` to use no outcome transform. Note that `.train()` will
                be called on the outcome transform during instantiation of the model.
            input_transform: An input transform that is applied in the model's
                forward pass.

        Raises:
            ValueError: If `taus` is provided and its shape is not `[num_taus]`.
            UnsupportedError: If `train_Y` has more than one output or `train_X`
                is not 2-dimensional.
        """
        # Validate all inputs before sampling, so that a rejected call does not
        # consume random numbers.
        if taus is not None and taus.shape != torch.Size([num_taus]):
            raise ValueError(
                f"Expected taus to be of shape {[num_taus]}. Got {taus.shape=}."
            )
        if train_Y.shape[-1] != 1:
            raise UnsupportedError(
                f"EnsembleMapSaasGP only supports single-output. Got {train_Y.shape=}."
            )
        if train_X.ndim != 2:
            raise UnsupportedError(
                f"EnsembleMapSaasGP only supports 2D inputs. Got {train_X.ndim=}."
            )
        if taus is None:
            taus = HalfCauchy(torch.tensor(0.1)).sample([num_taus])
        # The SAAS prior stores `taus` as a plain tensor rather than a parameter,
        # so it must match the dtype and device of the training data up front,
        # regardless of whether it was sampled or user-provided.
        taus = taus.to(train_X)
        # Add batch dimension for ensemble.
        train_X = train_X.repeat(num_taus, 1, 1)
        train_Y = train_Y.repeat(num_taus, 1, 1)
        if train_Yvar is not None:
            train_Yvar = train_Yvar.repeat(num_taus, 1, 1)
        # Construct the sub-modules.
        if input_transform is not None:
            # The transform may change the feature dimension, so the number of
            # ARD lengthscales is read off the transformed inputs.
            with torch.no_grad():
                transformed_X = input_transform(train_X)
            ard_num_dims = transformed_X.shape[-1]
        else:
            ard_num_dims = train_X.shape[-1]
        batch_shape = train_X.shape[:-2]  # This is torch.Size([num_taus]).
        mean_module = get_mean_module_with_normal_prior(batch_shape=batch_shape)
        base_kernel = MaternKernel(
            nu=2.5, ard_num_dims=ard_num_dims, batch_shape=batch_shape
        )
        # NOTE: need to call `to` to set device and dtype before calling
        # `add_saas_prior`, since the SAAS prior contains tensors that are not
        # parameters of the model, and therefore not automatically moved to the
        # correct device with a `to` call on the model.
        base_kernel.to(train_X)
        add_saas_prior(base_kernel=base_kernel, tau=taus)
        covar_module = ScaleKernel(
            base_kernel=base_kernel,
            outputscale_constraint=LogTransformedInterval(1e-2, 1e4, initial_value=10),
            batch_shape=batch_shape,
        )
        if train_Yvar is None:
            likelihood = get_gaussian_likelihood_with_gamma_prior(
                batch_shape=batch_shape
            )
        else:
            # With observed noise, no likelihood is passed and the choice is
            # left to `SingleTaskGP`.
            likelihood = None

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

    def posterior(
        self,
        X: Tensor,
        output_indices: list[int] | None = None,
        observation_noise: bool = False,
        posterior_transform: PosteriorTransform | None = None,
        **kwargs: Any,
    ) -> GaussianMixturePosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add the observation noise from the
                likelihood to the posterior. The value is forwarded unchanged
                to `SingleTaskGP.posterior`.
            posterior_transform: An optional PosteriorTransform.

        Returns:
            A `GaussianMixturePosterior` object. Includes observation noise
            if specified.
        """
        # `X` gets an extra dimension at `MCMC_DIM` before it is evaluated by
        # the batched model; presumably this is where the `num_taus` ensemble
        # members broadcast, matching what `GaussianMixturePosterior` expects.
        posterior = super().posterior(
            X=X.unsqueeze(MCMC_DIM),
            output_indices=output_indices,
            observation_noise=observation_noise,
            posterior_transform=posterior_transform,
            **kwargs,
        )
        # Only the underlying distribution is carried over into the mixture
        # posterior.
        return GaussianMixturePosterior(distribution=posterior.distribution)

test/models/test_map_saas.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from unittest import mock
1313

1414
import torch
15-
1615
from botorch.exceptions import UnsupportedError
1716
from botorch.fit import (
1817
fit_gpytorch_mll,
@@ -23,14 +22,17 @@
2322
from botorch.models.map_saas import (
2423
add_saas_prior,
2524
AdditiveMapSaasSingleTaskGP,
25+
EnsembleMapSaasGP,
2626
get_additive_map_saas_covar_module,
2727
get_gaussian_likelihood_with_gamma_prior,
2828
get_mean_module_with_normal_prior,
2929
)
3030
from botorch.models.transforms.input import AppendFeatures, FilterFeatures, Normalize
3131
from botorch.models.transforms.outcome import Standardize
3232
from botorch.optim.utils import get_parameters_and_bounds, sample_all_priors
33+
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
3334
from botorch.posteriors.gpytorch import GPyTorchPosterior
35+
from botorch.test_utils.mock import mock_optimize
3436
from botorch.utils.constraints import LogTransformedInterval
3537
from botorch.utils.testing import BotorchTestCase
3638
from gpytorch.constraints import Interval
@@ -510,6 +512,58 @@ def test_batch_model_fitting(self) -> None:
510512
atol=1e-3,
511513
)
512514

515+
@mock_optimize
def test_emsemble_map_saas(self) -> None:
    """Builds, fits and queries ``EnsembleMapSaasGP`` with and without options."""
    train_X, train_Y, test_X = self._get_data()
    num_taus = 8
    d = train_X.shape[-1]
    for with_options in (False, True):
        # Optional keyword arguments layered on top of the required ones.
        extra_inputs = {}
        if with_options:
            extra_inputs["train_Yvar"] = 0.1 * torch.rand_like(train_Y)
            extra_inputs["taus"] = torch.rand(num_taus).to(train_X)
            extra_inputs["input_transform"] = Normalize(d=d)
            extra_inputs["outcome_transform"] = None
        model = EnsembleMapSaasGP(
            train_X=train_X, train_Y=train_Y, num_taus=num_taus, **extra_inputs
        )
        # Checks that the prior is configured correctly.
        sample_all_priors(model)
        fit_gpytorch_mll(
            ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
        )
        # The kernel is a scaled Matern kernel with one set of ARD
        # lengthscales per ensemble member.
        self.assertIsInstance(model.covar_module, ScaleKernel)
        base_kernel = model.covar_module.base_kernel
        self.assertIsInstance(base_kernel, MaternKernel)
        expected_shape = torch.Size([num_taus, 1, d])
        self.assertEqual(base_kernel.lengthscale.shape, expected_shape)
        self.assertEqual(model.batch_shape, torch.Size([num_taus]))
        posterior = model.posterior(test_X)
        self.assertIsInstance(posterior, GaussianMixturePosterior)
        if not with_options:
            # Defaults: inferred noise, standardized outcomes, no input transform.
            self.assertIsInstance(model.likelihood, GaussianLikelihood)
            self.assertIsInstance(model.outcome_transform, Standardize)
            self.assertFalse(hasattr(model, "input_transform"))
        else:
            # Options: fixed noise, normalized inputs, no outcome transform.
            self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
            self.assertIsInstance(model.input_transform, Normalize)
            self.assertFalse(hasattr(model, "outcome_transform"))
553+
554+
def test_ensemble_map_saas_validation(self) -> None:
    """Checks that invalid constructor inputs are rejected with clear errors."""

    def assert_rejected(error_type, message, **model_kwargs) -> None:
        # Constructing the model with `model_kwargs` must raise `error_type`
        # with a message matching `message`.
        with self.assertRaisesRegex(error_type, message):
            EnsembleMapSaasGP(**model_kwargs)

    # `taus` whose length disagrees with `num_taus`.
    assert_rejected(
        ValueError,
        "Expected taus to be of shape",
        train_X=torch.rand(5, 3),
        train_Y=torch.rand(5, 1),
        num_taus=3,
        taus=torch.rand(2),
    )
    # More than one outcome column.
    assert_rejected(
        UnsupportedError,
        "only supports single-output",
        train_X=torch.rand(5, 3),
        train_Y=torch.rand(5, 2),
    )
    # Batched (3D) training inputs.
    assert_rejected(
        UnsupportedError,
        "only supports 2D inputs",
        train_X=torch.rand(2, 5, 3),
        train_Y=torch.rand(2, 5, 1),
    )
566+
513567

514568
class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase):
515569
def _get_data_and_model(

0 commit comments

Comments
 (0)