Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 20 additions & 53 deletions botorch/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

from __future__ import annotations

import warnings
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import partial
from itertools import filterfalse
from typing import Any
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage

import torch

from botorch.exceptions.errors import ModelFittingError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.logging import logger
Expand All @@ -27,7 +26,7 @@
SaasFullyBayesianSingleTaskGP,
)
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.map_saas import get_map_saas_model
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP, get_map_saas_model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
Expand All @@ -45,6 +44,7 @@
TensorCheckpoint,
)
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
Expand All @@ -53,7 +53,6 @@
from linear_operator.utils.errors import NotPSDError
from pyro.infer.mcmc import MCMC, NUTS
from torch import device, Tensor
from torch.distributions import HalfCauchy
from torch.nn import Parameter
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -443,13 +442,15 @@ def get_fitted_map_saas_ensemble(
train_Y: Tensor,
train_Yvar: Tensor | None = None,
input_transform: InputTransform | None = None,
outcome_transform: OutcomeTransform | None = None,
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
taus: Tensor | list[float] | None = None,
num_taus: int = 4,
optimizer_kwargs: dict[str, Any] | None = None,
) -> SaasFullyBayesianSingleTaskGP:
"""Get a fitted SAAS ensemble using several different tau values.

DEPRECATED: Please use `EnsembleMapSaasSingleTaskGP` directly!

Args:
train_X: Tensor of shape `n x d` with training inputs.
train_Y: Tensor of shape `n x 1` with training targets.
Expand All @@ -464,57 +465,23 @@ def get_fitted_map_saas_ensemble(
to fit_gpytorch_mll.

Returns:
A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
A fitted EnsembleMapSaasSingleTaskGP with a Matern kernel.
"""
tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
if taus is None:
taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
num_samples = len(taus)
if num_samples == 1:
raise ValueError(
"Use `get_fitted_map_saas_model` if you only specify one value of tau"
)

mean = torch.zeros(num_samples, **tkwargs)
outputscale = torch.zeros(num_samples, **tkwargs)
lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
noise = torch.zeros(num_samples, **tkwargs)

# Fit a model for each tau and save the hyperparameters
for i, tau in enumerate(taus):
model = get_fitted_map_saas_model(
train_X,
train_Y,
train_Yvar=train_Yvar,
input_transform=input_transform,
outcome_transform=outcome_transform,
tau=tau,
optimizer_kwargs=optimizer_kwargs,
)
mean[i] = model.mean_module.constant.detach().clone()
outputscale[i] = model.covar_module.outputscale.detach().clone()
lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
if train_Yvar is None:
noise[i] = model.likelihood.noise.detach().clone()

# Load the samples into a fully Bayesian SAAS model
ensemble_model = SaasFullyBayesianSingleTaskGP(
warnings.warn(
"get_fitted_map_saas_ensemble is deprecated and will be removed in v0.17. "
"Please use EnsembleMapSaasSingleTaskGP instead!",
DeprecationWarning,
stacklevel=2,
)
model = EnsembleMapSaasSingleTaskGP(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=(
input_transform.train() if input_transform is not None else None
),
num_taus=num_taus,
taus=taus,
input_transform=input_transform,
outcome_transform=outcome_transform,
)
mcmc_samples = {
"mean": mean,
"outputscale": outputscale,
"lengthscale": lengthscale,
}
if train_Yvar is None:
mcmc_samples["noise"] = noise
ensemble_model.train()
ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
ensemble_model.eval()
return ensemble_model
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
return model
104 changes: 16 additions & 88 deletions test/models/test_map_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_fitted_map_saas_ensemble,
get_fitted_map_saas_model,
)
from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP
from botorch.models import SingleTaskGP
from botorch.models.map_saas import (
add_saas_prior,
AdditiveMapSaasSingleTaskGP,
Expand Down Expand Up @@ -299,93 +299,21 @@ def test_get_saas_model(self) -> None:
self.assertTrue(loss < loss_short)

def test_get_saas_ensemble(self) -> None:
for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]):
tkwargs = {"device": self.device, "dtype": torch.double}
train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs)
d = train_X.shape[-1]
train_Yvar = (
None
if infer_noise
else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1)
)
# Fit without specifying tau
with torch.random.fork_rng():
torch.manual_seed(0)
model = get_fitted_map_saas_ensemble(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=Normalize(d=d),
outcome_transform=Standardize(m=1),
taus=taus,
)
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
num_taus = 4 if taus is None else len(taus)
self.assertEqual(
model.covar_module.base_kernel.lengthscale.shape,
torch.Size([num_taus, 1, d]),
)
self.assertEqual(model.batch_shape, torch.Size([num_taus]))
# Make sure the lengthscales are reasonable
self.assertGreater(
model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50
)
self.assertLess(
model.covar_module.base_kernel.lengthscale[..., 0].max(), 10
)

# testing optimizer_options: short optimization run with maxiter = 3
with torch.random.fork_rng():
torch.manual_seed(0)
fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll)
with mock.patch(
"botorch.fit.fit_gpytorch_mll",
new=fit_gpytorch_mll_mock,
):
maxiter = 3
model_short = get_fitted_map_saas_ensemble(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=Normalize(d=d),
outcome_transform=Standardize(m=1),
taus=taus,
optimizer_kwargs={"options": {"maxiter": maxiter}},
)
kwargs = fit_gpytorch_mll_mock.call_args.kwargs
# fit_gpytorch_mll has "option" kwarg, not "optimizer_options"
self.assertEqual(
kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter
)

# compute sum of marginal likelihoods of ensemble after short run
# NOTE: We can't put MLL in train mode here since
# SaasFullyBayesianSingleTaskGP requires NUTS for training.
mll_short = ExactMarginalLogLikelihood(
model=model_short, likelihood=model_short.likelihood
train_X, train_Y, _ = self._get_data_hardcoded(device=self.device)
with self.assertWarnsRegex(
DeprecationWarning, "EnsembleMapSaasSingleTaskGP"
), mock.patch("botorch.fit.fit_gpytorch_mll") as mock_fit:
model = get_fitted_map_saas_ensemble(
train_X=train_X,
train_Y=train_Y,
input_transform=Normalize(d=train_X.shape[-1]),
outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
optimizer_kwargs={"options": {"maxiter": 3}},
)
train_inputs = mll_short.model.train_inputs
train_targets = mll_short.model.train_targets
loss_short = -mll_short(model_short(*train_inputs), train_targets)
# compute sum of marginal likelihoods of ensemble after standard run
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
# reusing train_inputs and train_targets, since the transforms are the same
loss = -mll(model(*train_inputs), train_targets)
# the longer running optimization should have smaller loss than the shorter
self.assertLess((loss - loss_short).max(), 0.0)

# test error message
with self.assertRaisesRegex(
ValueError, "if you only specify one value of tau"
):
model_short = get_fitted_map_saas_ensemble(
train_X=train_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
input_transform=Normalize(d=d),
outcome_transform=Standardize(m=1),
taus=[0.1],
)
self.assertEqual(
mock_fit.call_args.kwargs["optimizer_kwargs"], {"options": {"maxiter": 3}}
)
self.assertIsInstance(model, EnsembleMapSaasSingleTaskGP)

def test_input_transform_in_train(self) -> None:
train_X, train_Y, test_X = self._get_data()
Expand Down Expand Up @@ -522,7 +450,7 @@ def test_batch_model_fitting(self) -> None:

@mock_optimize
def test_emsemble_map_saas(self) -> None:
train_X, train_Y, test_X = self._get_data()
train_X, train_Y, test_X = self._get_data(device=self.device)
d = train_X.shape[-1]
num_taus = 8
for with_options in (False, True):
Expand Down