Skip to content

Commit 082057f

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Deprecate get_fitted_map_saas_ensemble in favor of EnsembleMapSaasGP (#3036)
Summary: With `EnsembleMapSaasGP`, we no longer need a helper that constructs a `SaasFullyBayasianSingleTaskGP` from individually fitted models. Differential Revision: D83782823
1 parent e390e13 commit 082057f

File tree

2 files changed

+36
-141
lines changed

2 files changed

+36
-141
lines changed

botorch/fit.py

Lines changed: 20 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@
88

99
from __future__ import annotations
1010

11+
import warnings
1112
from collections.abc import Callable, Sequence
1213
from copy import deepcopy
1314
from functools import partial
1415
from itertools import filterfalse
1516
from typing import Any
1617
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
1718

18-
import torch
19-
2019
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
2120
from botorch.exceptions.warnings import OptimizationWarning
2221
from botorch.logging import logger
@@ -27,7 +26,7 @@
2726
SaasFullyBayesianSingleTaskGP,
2827
)
2928
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
30-
from botorch.models.map_saas import get_map_saas_model
29+
from botorch.models.map_saas import EnsembleMapSaasSingleTaskGP, get_map_saas_model
3130
from botorch.models.model_list_gp_regression import ModelListGP
3231
from botorch.models.transforms.input import InputTransform
3332
from botorch.models.transforms.outcome import OutcomeTransform
@@ -45,6 +44,7 @@
4544
TensorCheckpoint,
4645
)
4746
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
47+
from botorch.utils.types import _DefaultType, DEFAULT
4848
from gpytorch.likelihoods import Likelihood
4949
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
5050
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
@@ -53,7 +53,6 @@
5353
from linear_operator.utils.errors import NotPSDError
5454
from pyro.infer.mcmc import MCMC, NUTS
5555
from torch import device, Tensor
56-
from torch.distributions import HalfCauchy
5756
from torch.nn import Parameter
5857
from torch.utils.data import DataLoader
5958

@@ -443,13 +442,15 @@ def get_fitted_map_saas_ensemble(
443442
train_Y: Tensor,
444443
train_Yvar: Tensor | None = None,
445444
input_transform: InputTransform | None = None,
446-
outcome_transform: OutcomeTransform | None = None,
445+
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
447446
taus: Tensor | list[float] | None = None,
448447
num_taus: int = 4,
449448
optimizer_kwargs: dict[str, Any] | None = None,
450449
) -> SaasFullyBayesianSingleTaskGP:
451450
"""Get a fitted SAAS ensemble using several different tau values.
452451
452+
DEPRECATED: Please use `EnsembleMapSaasSingleTaskGP` directly!
453+
453454
Args:
454455
train_X: Tensor of shape `n x d` with training inputs.
455456
train_Y: Tensor of shape `n x 1` with training targets.
@@ -464,57 +465,23 @@ def get_fitted_map_saas_ensemble(
464465
to fit_gpytorch_mll.
465466
466467
Returns:
467-
A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
468+
A fitted EnsembleMapSaasSingleTaskGP with a Matern kernel.
468469
"""
469-
tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
470-
if taus is None:
471-
taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
472-
num_samples = len(taus)
473-
if num_samples == 1:
474-
raise ValueError(
475-
"Use `get_fitted_map_saas_model` if you only specify one value of tau"
476-
)
477-
478-
mean = torch.zeros(num_samples, **tkwargs)
479-
outputscale = torch.zeros(num_samples, **tkwargs)
480-
lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
481-
noise = torch.zeros(num_samples, **tkwargs)
482-
483-
# Fit a model for each tau and save the hyperparameters
484-
for i, tau in enumerate(taus):
485-
model = get_fitted_map_saas_model(
486-
train_X,
487-
train_Y,
488-
train_Yvar=train_Yvar,
489-
input_transform=input_transform,
490-
outcome_transform=outcome_transform,
491-
tau=tau,
492-
optimizer_kwargs=optimizer_kwargs,
493-
)
494-
mean[i] = model.mean_module.constant.detach().clone()
495-
outputscale[i] = model.covar_module.outputscale.detach().clone()
496-
lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
497-
if train_Yvar is None:
498-
noise[i] = model.likelihood.noise.detach().clone()
499-
500-
# Load the samples into a fully Bayesian SAAS model
501-
ensemble_model = SaasFullyBayesianSingleTaskGP(
470+
warnings.warn(
471+
"get_fitted_map_saas_ensemble is deprecated and will be removed in v0.17. "
472+
"Please use EnsembleMapSaasSingleTaskGP instead!",
473+
DeprecationWarning,
474+
stacklevel=2,
475+
)
476+
model = EnsembleMapSaasSingleTaskGP(
502477
train_X=train_X,
503478
train_Y=train_Y,
504479
train_Yvar=train_Yvar,
505-
input_transform=(
506-
input_transform.train() if input_transform is not None else None
507-
),
480+
num_taus=num_taus,
481+
taus=taus,
482+
input_transform=input_transform,
508483
outcome_transform=outcome_transform,
509484
)
510-
mcmc_samples = {
511-
"mean": mean,
512-
"outputscale": outputscale,
513-
"lengthscale": lengthscale,
514-
}
515-
if train_Yvar is None:
516-
mcmc_samples["noise"] = noise
517-
ensemble_model.train()
518-
ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
519-
ensemble_model.eval()
520-
return ensemble_model
485+
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
486+
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
487+
return model

test/models/test_map_saas.py

Lines changed: 16 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
get_fitted_map_saas_ensemble,
2020
get_fitted_map_saas_model,
2121
)
22-
from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP
22+
from botorch.models import SingleTaskGP
2323
from botorch.models.map_saas import (
2424
add_saas_prior,
2525
AdditiveMapSaasSingleTaskGP,
@@ -299,93 +299,21 @@ def test_get_saas_model(self) -> None:
299299
self.assertTrue(loss < loss_short)
300300

301301
def test_get_saas_ensemble(self) -> None:
302-
for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]):
303-
tkwargs = {"device": self.device, "dtype": torch.double}
304-
train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs)
305-
d = train_X.shape[-1]
306-
train_Yvar = (
307-
None
308-
if infer_noise
309-
else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1)
310-
)
311-
# Fit without specifying tau
312-
with torch.random.fork_rng():
313-
torch.manual_seed(0)
314-
model = get_fitted_map_saas_ensemble(
315-
train_X=train_X,
316-
train_Y=train_Y,
317-
train_Yvar=train_Yvar,
318-
input_transform=Normalize(d=d),
319-
outcome_transform=Standardize(m=1),
320-
taus=taus,
321-
)
322-
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
323-
num_taus = 4 if taus is None else len(taus)
324-
self.assertEqual(
325-
model.covar_module.base_kernel.lengthscale.shape,
326-
torch.Size([num_taus, 1, d]),
327-
)
328-
self.assertEqual(model.batch_shape, torch.Size([num_taus]))
329-
# Make sure the lengthscales are reasonable
330-
self.assertGreater(
331-
model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50
332-
)
333-
self.assertLess(
334-
model.covar_module.base_kernel.lengthscale[..., 0].max(), 10
335-
)
336-
337-
# testing optimizer_options: short optimization run with maxiter = 3
338-
with torch.random.fork_rng():
339-
torch.manual_seed(0)
340-
fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll)
341-
with mock.patch(
342-
"botorch.fit.fit_gpytorch_mll",
343-
new=fit_gpytorch_mll_mock,
344-
):
345-
maxiter = 3
346-
model_short = get_fitted_map_saas_ensemble(
347-
train_X=train_X,
348-
train_Y=train_Y,
349-
train_Yvar=train_Yvar,
350-
input_transform=Normalize(d=d),
351-
outcome_transform=Standardize(m=1),
352-
taus=taus,
353-
optimizer_kwargs={"options": {"maxiter": maxiter}},
354-
)
355-
kwargs = fit_gpytorch_mll_mock.call_args.kwargs
356-
# fit_gpytorch_mll has "option" kwarg, not "optimizer_options"
357-
self.assertEqual(
358-
kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter
359-
)
360-
361-
# compute sum of marginal likelihoods of ensemble after short run
362-
# NOTE: We can't put MLL in train mode here since
363-
# SaasFullyBayesianSingleTaskGP requires NUTS for training.
364-
mll_short = ExactMarginalLogLikelihood(
365-
model=model_short, likelihood=model_short.likelihood
302+
train_X, train_Y, _ = self._get_data_hardcoded(device=self.device)
303+
with self.assertWarnsRegex(
304+
DeprecationWarning, "EnsembleMapSaasSingleTaskGP"
305+
), mock.patch("botorch.fit.fit_gpytorch_mll") as mock_fit:
306+
model = get_fitted_map_saas_ensemble(
307+
train_X=train_X,
308+
train_Y=train_Y,
309+
input_transform=Normalize(d=train_X.shape[-1]),
310+
outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
311+
optimizer_kwargs={"options": {"maxiter": 3}},
366312
)
367-
train_inputs = mll_short.model.train_inputs
368-
train_targets = mll_short.model.train_targets
369-
loss_short = -mll_short(model_short(*train_inputs), train_targets)
370-
# compute sum of marginal likelihoods of ensemble after standard run
371-
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
372-
# reusing train_inputs and train_targets, since the transforms are the same
373-
loss = -mll(model(*train_inputs), train_targets)
374-
# the longer running optimization should have smaller loss than the shorter
375-
self.assertLess((loss - loss_short).max(), 0.0)
376-
377-
# test error message
378-
with self.assertRaisesRegex(
379-
ValueError, "if you only specify one value of tau"
380-
):
381-
model_short = get_fitted_map_saas_ensemble(
382-
train_X=train_X,
383-
train_Y=train_Y,
384-
train_Yvar=train_Yvar,
385-
input_transform=Normalize(d=d),
386-
outcome_transform=Standardize(m=1),
387-
taus=[0.1],
388-
)
313+
self.assertEqual(
314+
mock_fit.call_args.kwargs["optimizer_kwargs"], {"options": {"maxiter": 3}}
315+
)
316+
self.assertIsInstance(model, EnsembleMapSaasSingleTaskGP)
389317

390318
def test_input_transform_in_train(self) -> None:
391319
train_X, train_Y, test_X = self._get_data()
@@ -522,7 +450,7 @@ def test_batch_model_fitting(self) -> None:
522450

523451
@mock_optimize
524452
def test_emsemble_map_saas(self) -> None:
525-
train_X, train_Y, test_X = self._get_data()
453+
train_X, train_Y, test_X = self._get_data(device=self.device)
526454
d = train_X.shape[-1]
527455
num_taus = 8
528456
for with_options in (False, True):

0 commit comments

Comments
 (0)