Skip to content

Commit 93f2c85

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Deprecate get_fitted_map_saas_ensemble in favor of EnsembleMapSaasGP
Summary: With `EnsembleMapSaasGP`, we no longer need a helper that constructs a `SaasFullyBayasianSingleTaskGP` from individually fitted models. Differential Revision: D83782823
1 parent 4b2363d commit 93f2c85

File tree

2 files changed

+37
-141
lines changed

2 files changed

+37
-141
lines changed

botorch/fit.py

Lines changed: 17 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
from typing import Any
1616
from warnings import catch_warnings, simplefilter, warn_explicit, WarningMessage
1717

18-
import torch
19-
2018
from botorch.exceptions.errors import ModelFittingError, UnsupportedError
2119
from botorch.exceptions.warnings import OptimizationWarning
2220
from botorch.logging import logger
@@ -27,7 +25,7 @@
2725
SaasFullyBayesianSingleTaskGP,
2826
)
2927
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
30-
from botorch.models.map_saas import get_map_saas_model
28+
from botorch.models.map_saas import EnsembleMapSaasGP, get_map_saas_model
3129
from botorch.models.model_list_gp_regression import ModelListGP
3230
from botorch.models.transforms.input import InputTransform
3331
from botorch.models.transforms.outcome import OutcomeTransform
@@ -45,6 +43,7 @@
4543
TensorCheckpoint,
4644
)
4745
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
46+
from botorch.utils.types import _DefaultType, DEFAULT
4847
from gpytorch.likelihoods import Likelihood
4948
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood
5049
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
@@ -53,7 +52,6 @@
5352
from linear_operator.utils.errors import NotPSDError
5453
from pyro.infer.mcmc import MCMC, NUTS
5554
from torch import device, Tensor
56-
from torch.distributions import HalfCauchy
5755
from torch.nn import Parameter
5856
from torch.utils.data import DataLoader
5957

@@ -443,13 +441,15 @@ def get_fitted_map_saas_ensemble(
443441
train_Y: Tensor,
444442
train_Yvar: Tensor | None = None,
445443
input_transform: InputTransform | None = None,
446-
outcome_transform: OutcomeTransform | None = None,
444+
outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
447445
taus: Tensor | list[float] | None = None,
448446
num_taus: int = 4,
449447
optimizer_kwargs: dict[str, Any] | None = None,
450448
) -> SaasFullyBayesianSingleTaskGP:
451449
"""Get a fitted SAAS ensemble using several different tau values.
452450
451+
DEPRECATED: Please use `EnsembleMapSaasGP` directly!
452+
453453
Args:
454454
train_X: Tensor of shape `n x d` with training inputs.
455455
train_Y: Tensor of shape `n x 1` with training targets.
@@ -464,57 +464,21 @@ def get_fitted_map_saas_ensemble(
464464
to fit_gpytorch_mll.
465465
466466
Returns:
467-
A fitted SaasFullyBayesianSingleTaskGP with a Matern kernel.
467+
A fitted EnsembleMapSaasGP with a Matern kernel.
468468
"""
469-
tkwargs = {"device": train_X.device, "dtype": train_X.dtype}
470-
if taus is None:
471-
taus = HalfCauchy(0.1).sample([num_taus]).to(**tkwargs)
472-
num_samples = len(taus)
473-
if num_samples == 1:
474-
raise ValueError(
475-
"Use `get_fitted_map_saas_model` if you only specify one value of tau"
476-
)
477-
478-
mean = torch.zeros(num_samples, **tkwargs)
479-
outputscale = torch.zeros(num_samples, **tkwargs)
480-
lengthscale = torch.zeros(num_samples, train_X.shape[-1], **tkwargs)
481-
noise = torch.zeros(num_samples, **tkwargs)
482-
483-
# Fit a model for each tau and save the hyperparameters
484-
for i, tau in enumerate(taus):
485-
model = get_fitted_map_saas_model(
486-
train_X,
487-
train_Y,
488-
train_Yvar=train_Yvar,
489-
input_transform=input_transform,
490-
outcome_transform=outcome_transform,
491-
tau=tau,
492-
optimizer_kwargs=optimizer_kwargs,
493-
)
494-
mean[i] = model.mean_module.constant.detach().clone()
495-
outputscale[i] = model.covar_module.outputscale.detach().clone()
496-
lengthscale[i, :] = model.covar_module.base_kernel.lengthscale.detach().clone()
497-
if train_Yvar is None:
498-
noise[i] = model.likelihood.noise.detach().clone()
499-
500-
# Load the samples into a fully Bayesian SAAS model
501-
ensemble_model = SaasFullyBayesianSingleTaskGP(
469+
logger.warning(
470+
"get_fitted_map_saas_ensemble is deprecated and will be removed in v0.17. "
471+
"Please use EnsembleMapSaasGP instead!"
472+
)
473+
model = EnsembleMapSaasGP(
502474
train_X=train_X,
503475
train_Y=train_Y,
504476
train_Yvar=train_Yvar,
505-
input_transform=(
506-
input_transform.train() if input_transform is not None else None
507-
),
477+
num_taus=num_taus,
478+
taus=taus,
479+
input_transform=input_transform,
508480
outcome_transform=outcome_transform,
509481
)
510-
mcmc_samples = {
511-
"mean": mean,
512-
"outputscale": outputscale,
513-
"lengthscale": lengthscale,
514-
}
515-
if train_Yvar is None:
516-
mcmc_samples["noise"] = noise
517-
ensemble_model.train()
518-
ensemble_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
519-
ensemble_model.eval()
520-
return ensemble_model
482+
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
483+
fit_gpytorch_mll(mll, optimizer_kwargs=optimizer_kwargs)
484+
return model

test/models/test_map_saas.py

Lines changed: 20 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
fit_gpytorch_mll,
1818
get_fitted_map_saas_ensemble,
1919
get_fitted_map_saas_model,
20+
logger,
2021
)
21-
from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP
22+
from botorch.models import SingleTaskGP
2223
from botorch.models.map_saas import (
2324
add_saas_prior,
2425
AdditiveMapSaasSingleTaskGP,
@@ -291,93 +292,24 @@ def test_get_saas_model(self) -> None:
291292
self.assertTrue(loss < loss_short)
292293

293294
def test_get_saas_ensemble(self) -> None:
294-
for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]):
295-
tkwargs = {"device": self.device, "dtype": torch.double}
296-
train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs)
297-
d = train_X.shape[-1]
298-
train_Yvar = (
299-
None
300-
if infer_noise
301-
else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1)
302-
)
303-
# Fit without specifying tau
304-
with torch.random.fork_rng():
305-
torch.manual_seed(0)
306-
model = get_fitted_map_saas_ensemble(
307-
train_X=train_X,
308-
train_Y=train_Y,
309-
train_Yvar=train_Yvar,
310-
input_transform=Normalize(d=d),
311-
outcome_transform=Standardize(m=1),
312-
taus=taus,
313-
)
314-
self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
315-
num_taus = 4 if taus is None else len(taus)
316-
self.assertEqual(
317-
model.covar_module.base_kernel.lengthscale.shape,
318-
torch.Size([num_taus, 1, d]),
319-
)
320-
self.assertEqual(model.batch_shape, torch.Size([num_taus]))
321-
# Make sure the lengthscales are reasonable
322-
self.assertGreater(
323-
model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50
324-
)
325-
self.assertLess(
326-
model.covar_module.base_kernel.lengthscale[..., 0].max(), 10
327-
)
328-
329-
# testing optimizer_options: short optimization run with maxiter = 3
330-
with torch.random.fork_rng():
331-
torch.manual_seed(0)
332-
fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll)
333-
with mock.patch(
334-
"botorch.fit.fit_gpytorch_mll",
335-
new=fit_gpytorch_mll_mock,
336-
):
337-
maxiter = 3
338-
model_short = get_fitted_map_saas_ensemble(
339-
train_X=train_X,
340-
train_Y=train_Y,
341-
train_Yvar=train_Yvar,
342-
input_transform=Normalize(d=d),
343-
outcome_transform=Standardize(m=1),
344-
taus=taus,
345-
optimizer_kwargs={"options": {"maxiter": maxiter}},
346-
)
347-
kwargs = fit_gpytorch_mll_mock.call_args.kwargs
348-
# fit_gpytorch_mll has "option" kwarg, not "optimizer_options"
349-
self.assertEqual(
350-
kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter
351-
)
352-
353-
# compute sum of marginal likelihoods of ensemble after short run
354-
# NOTE: We can't put MLL in train mode here since
355-
# SaasFullyBayesianSingleTaskGP requires NUTS for training.
356-
mll_short = ExactMarginalLogLikelihood(
357-
model=model_short, likelihood=model_short.likelihood
295+
train_X, train_Y, _ = self._get_data_hardcoded(device=self.device)
296+
with self.assertLogs(logger=logger, level="WARNING") as logs, mock.patch(
297+
"botorch.fit.fit_gpytorch_mll"
298+
) as mock_fit:
299+
model = get_fitted_map_saas_ensemble(
300+
train_X=train_X,
301+
train_Y=train_Y,
302+
input_transform=Normalize(d=train_X.shape[-1]),
303+
outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
304+
optimizer_kwargs={"options": {"maxiter": 3}},
358305
)
359-
train_inputs = mll_short.model.train_inputs
360-
train_targets = mll_short.model.train_targets
361-
loss_short = -mll_short(model_short(*train_inputs), train_targets)
362-
# compute sum of marginal likelihoods of ensemble after standard run
363-
mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
364-
# reusing train_inputs and train_targets, since the transforms are the same
365-
loss = -mll(model(*train_inputs), train_targets)
366-
# the longer running optimization should have smaller loss than the shorter
367-
self.assertLess((loss - loss_short).max(), 0.0)
368-
369-
# test error message
370-
with self.assertRaisesRegex(
371-
ValueError, "if you only specify one value of tau"
372-
):
373-
model_short = get_fitted_map_saas_ensemble(
374-
train_X=train_X,
375-
train_Y=train_Y,
376-
train_Yvar=train_Yvar,
377-
input_transform=Normalize(d=d),
378-
outcome_transform=Standardize(m=1),
379-
taus=[0.1],
380-
)
306+
self.assertTrue(
307+
any("use EnsembleMapSaasGP instead" in output for output in logs.output)
308+
)
309+
self.assertEqual(
310+
mock_fit.call_args.kwargs["optimizer_kwargs"], {"options": {"maxiter": 3}}
311+
)
312+
self.assertIsInstance(model, EnsembleMapSaasGP)
381313

382314
def test_input_transform_in_train(self) -> None:
383315
train_X, train_Y, test_X = self._get_data()
@@ -514,7 +446,7 @@ def test_batch_model_fitting(self) -> None:
514446

515447
@mock_optimize
516448
def test_emsemble_map_saas(self) -> None:
517-
train_X, train_Y, test_X = self._get_data()
449+
train_X, train_Y, test_X = self._get_data(device=self.device)
518450
d = train_X.shape[-1]
519451
num_taus = 8
520452
for with_options in (False, True):

0 commit comments

Comments
 (0)