Skip to content

Commit 7e6a7db

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Update the default SingleTaskGP prior (facebook#2610)
Summary: Pull Request resolved: facebook#2610 X-link: pytorch/botorch#2449 Updates the default hyperparameter priors for the SingleTaskGP. The conventional Scale-Matern kernel with a Gamma(3, 6) lengthscale prior is replaced by an RBF kernel (without a ScaleKernel), and the high-noise Gamma(1.1, 0.05) noise prior of the GaussianLikelihood is replaced by a LogNormal prior that prefers lower values. The change follows the findings of [1] and is intended to improve the out-of-the-box performance of the BoTorch models on high-dimensional problems. [1] Carl Hvarfner, Erik Orm Hellsten, Luigi Nardi. _Vanilla Bayesian Optimization Performs Great in High Dimensions_. ICML, 2024. Reviewed By: dme65, saitcakmak Differential Revision: D60080819 fbshipit-source-id: d55ff91dee9949cbd7f5828531644fc001cb3e22
1 parent 8e07000 commit 7e6a7db

File tree

6 files changed

+45
-34
lines changed

6 files changed

+45
-34
lines changed

ax/models/tests/test_botorch_defaults.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
import math
910
from copy import deepcopy
1011
from unittest import mock
1112
from unittest.mock import Mock
@@ -66,9 +67,9 @@ def test_get_model(self) -> None:
6667
self.assertIsInstance(model, SingleTaskGP)
6768
self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
6869
self.assertEqual(
69-
model.covar_module.base_kernel.lengthscale_prior.concentration, 3.0
70+
model.covar_module.lengthscale_prior.loc, math.log(2.0) / 2 + 2**0.5
7071
)
71-
self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 6.0)
72+
self.assertEqual(model.covar_module.lengthscale_prior.scale, 3**0.5)
7273
model = _get_model(X=x, Y=y, Yvar=unknown_var, task_feature=1)
7374
self.assertIs(type(model), MultiTaskGP) # Don't accept subclasses.
7475
self.assertIsInstance(model.likelihood, GaussianLikelihood)

ax/models/tests/test_botorch_model.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from botorch.models.transforms.input import Warp
3737
from botorch.utils.datasets import SupervisedDataset
3838
from botorch.utils.objective import get_objective_weights_transform
39+
from gpytorch.kernels.constant_kernel import ConstantKernel
3940
from gpytorch.likelihoods import _GaussianLikelihoodBase
4041
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
4142
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood
@@ -558,19 +559,12 @@ def test_BotorchModel(
558559

559560
# Test loading state dict
560561
true_state_dict = {
561-
"mean_module.raw_constant": 3.5004,
562-
"covar_module.raw_outputscale": 2.2438,
563-
"covar_module.base_kernel.raw_lengthscale": [
564-
[-0.9274, -0.9274, -0.9274]
565-
],
566-
"covar_module.base_kernel.raw_lengthscale_constraint.lower_bound": 0.1,
567-
"covar_module.base_kernel.raw_lengthscale_constraint.upper_bound": 2.5,
568-
"covar_module.base_kernel.lengthscale_prior.concentration": 3.0,
569-
"covar_module.base_kernel.lengthscale_prior.rate": 6.0,
570-
"covar_module.raw_outputscale_constraint.lower_bound": 0.2,
571-
"covar_module.raw_outputscale_constraint.upper_bound": 2.6,
572-
"covar_module.outputscale_prior.concentration": 2.0,
573-
"covar_module.outputscale_prior.rate": 0.15,
562+
"mean_module.raw_constant": 1.0,
563+
"covar_module.raw_lengthscale": [[0.3548, 0.3548, 0.3548]],
564+
"covar_module.lengthscale_prior._transformed_loc": 1.9635,
565+
"covar_module.lengthscale_prior._transformed_scale": 1.7321,
566+
"covar_module.raw_lengthscale_constraint.lower_bound": 0.0250,
567+
"covar_module.raw_lengthscale_constraint.upper_bound": float("inf"),
574568
}
575569
true_state_dict = {
576570
key: torch.tensor(val, **tkwargs)
@@ -591,8 +585,7 @@ def test_BotorchModel(
591585

592586
# Test for some change in model parameters & buffer for refit_model=True
593587
true_state_dict["mean_module.raw_constant"] += 0.1
594-
true_state_dict["covar_module.raw_outputscale"] += 0.1
595-
true_state_dict["covar_module.base_kernel.raw_lengthscale"] += 0.1
588+
true_state_dict["covar_module.raw_lengthscale"] += 0.1
596589
model = get_and_fit_model(
597590
Xs=Xs1,
598591
Ys=Ys1,
@@ -774,17 +767,16 @@ def test_get_feature_importances_from_botorch_model(self) -> None:
774767
train_X = torch.rand(5, 3, **tkwargs)
775768
train_Y = train_X.sum(dim=-1, keepdim=True)
776769
simple_gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
777-
simple_gp.covar_module.base_kernel.lengthscale = torch.tensor(
778-
[1, 3, 5], **tkwargs
779-
)
770+
simple_gp.covar_module.lengthscale = torch.tensor([1, 3, 5], **tkwargs)
780771
importances = get_feature_importances_from_botorch_model(simple_gp)
781772
self.assertTrue(np.allclose(importances, np.array([15 / 23, 5 / 23, 3 / 23])))
782773
self.assertEqual(importances.shape, (1, 1, 3))
783-
# Model with no base kernel
784-
simple_gp.covar_module.base_kernel = None
774+
# Model with kernel that has no lengthscales
775+
simple_gp.covar_module = ConstantKernel()
785776
with self.assertRaisesRegex(
786777
NotImplementedError,
787-
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
778+
"Failed to extract lengthscales from `m.covar_module` and "
779+
"`m.covar_module.base_kernel`",
788780
):
789781
get_feature_importances_from_botorch_model(simple_gp)
790782

ax/models/torch/botorch.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,15 +562,21 @@ def get_feature_importances_from_botorch_model(
562562
lengthscales = []
563563
for m in models:
564564
try:
565-
ls = m.covar_module.base_kernel.lengthscale
565+
# this can be a ModelList of a SAAS and STGP, so this is a necessary way
566+
# to get the lengthscale
567+
if hasattr(m.covar_module, "base_kernel"):
568+
ls = m.covar_module.base_kernel.lengthscale
569+
else:
570+
ls = m.covar_module.lengthscale
566571
except AttributeError:
567572
ls = None
568573
if ls is None or ls.shape[-1] != m.train_inputs[0].shape[-1]:
569574
# TODO: We could potentially set the feature importances to NaN in this
570575
# case, but this require knowing the batch dimension of this model.
571576
# Consider supporting in the future.
572577
raise NotImplementedError(
573-
"Failed to extract lengthscales from `m.covar_module.base_kernel`"
578+
"Failed to extract lengthscales from `m.covar_module` "
579+
"and `m.covar_module.base_kernel`"
574580
)
575581
if ls.ndim == 2:
576582
ls = ls.unsqueeze(0)

ax/models/torch/tests/test_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,8 @@ def test_feature_importances(self) -> None:
634634
self.assertEqual(importances.shape, (1, 1, 3))
635635
saas_model = deepcopy(model.surrogate.model)
636636
else:
637-
model.surrogate.model.covar_module.base_kernel.lengthscale = (
638-
torch.tensor([1, 2, 3], **self.tkwargs)
637+
model.surrogate.model.covar_module.lengthscale = torch.tensor(
638+
[1, 2, 3], **self.tkwargs
639639
)
640640
importances = model.feature_importances()
641641
self.assertTrue(
@@ -658,11 +658,12 @@ def test_feature_importances(self) -> None:
658658
)
659659
self.assertEqual(importances.shape, (2, 1, 3))
660660
# Add model we don't support
661-
vanilla_model.covar_module.base_kernel = None
661+
vanilla_model.covar_module = None
662662
model.surrogate._model = vanilla_model # pyre-ignore
663663
with self.assertRaisesRegex(
664664
NotImplementedError,
665-
"Failed to extract lengthscales from `m.covar_module.base_kernel`",
665+
"Failed to extract lengthscales from `m.covar_module` "
666+
"and `m.covar_module.base_kernel`",
666667
):
667668
model.feature_importances()
668669
# Test model is None

ax/plot/tests/test_feature_importances.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def get_sensitivity_values(ax_model: ModelBridge) -> Dict:
4747
4848
Returns map {'metric_name': {'parameter_name': sensitivity_value}}
4949
"""
50-
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
50+
if hasattr(ax_model.model.model.covar_module, "outputscale"):
51+
ls = ax_model.model.model.covar_module.base_kernel.lengthscale.squeeze()
52+
else:
53+
ls = ax_model.model.model.covar_module.lengthscale.squeeze()
5154
if len(ls.shape) > 1:
5255
ls = ls.mean(dim=0)
5356
# pyre-fixme[16]: `float` has no attribute `detach`.

ax/utils/sensitivity/derivative_gp.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
3737
D = X.shape[1]
3838
N = X.shape[0]
3939
n = x.shape[0]
40-
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
40+
if hasattr(gp.covar_module, "outputscale"):
41+
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
42+
sigma_f = gp.covar_module.outputscale.detach()
43+
else:
44+
lengthscale = gp.covar_module.lengthscale.detach()
45+
sigma_f = 1.0
4146
if kernel_type == "rbf":
4247
K_xX = gp.covar_module(x, X).evaluate()
4348
part1 = -torch.eye(D, device=x.device, dtype=x.dtype) / lengthscale**2
@@ -52,7 +57,6 @@ def get_KxX_dx(gp: Model, x: Tensor, kernel_type: str = "rbf") -> Tensor:
5257
constant_component = (-5.0 / 3.0) * distance - (5.0 * math.sqrt(5.0) / 3.0) * (
5358
distance**2
5459
)
55-
sigma_f = gp.covar_module.outputscale.detach()
5660
part1 = torch.eye(D, device=lengthscale.device) / lengthscale
5761
part2 = (x1_.view(n, 1, D) - x2_.view(1, N, D)) / distance.unsqueeze(2)
5862
total_k = sigma_f * constant_component * exp_component
@@ -70,8 +74,12 @@ def get_Kxx_dx2(gp: Model, kernel_type: str = "rbf") -> Tensor:
7074
"""
7175
X = gp.train_inputs[0]
7276
D = X.shape[1]
73-
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
74-
sigma_f = gp.covar_module.outputscale.detach()
77+
if hasattr(gp.covar_module, "outputscale"):
78+
lengthscale = gp.covar_module.base_kernel.lengthscale.detach()
79+
sigma_f = gp.covar_module.outputscale.detach()
80+
else:
81+
lengthscale = gp.covar_module.lengthscale.detach()
82+
sigma_f = 1.0
7583
res = (torch.eye(D, device=lengthscale.device) / lengthscale**2) * sigma_f
7684
if kernel_type == "rbf":
7785
return res

0 commit comments

Comments
 (0)