Skip to content

Commit b77590c

Browse files
VilockLifacebook-github-bot
authored andcommitted
Embeding fidelity kernels into SingleTaskGP (#181)
Summary: Pull Request resolved: #181 We write SingleTaskMultiFidelityGP as a subclass of SingleTaskGP to deal with GP models that have fidelity parameters. The kernel used in this new GP model comes from Section 5.3 of https://arxiv.org/abs/1903.04703. We also added mean_module and covar_module kwargs to the SingleTaskGP constructor that default to None. Reviewed By: Balandat Differential Revision: D15731799 fbshipit-source-id: 5835c5843461bcdf30430e5aae5d2696867f3667
1 parent 3910976 commit b77590c

File tree

5 files changed

+471
-13
lines changed

5 files changed

+471
-13
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4+
5+
from .gp_regression_fidelity import SingleTaskMultiFidelityGP
6+
7+
8+
__all__ = ["SingleTaskMultiFidelityGP"]
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#! /usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4+
5+
r"""
6+
Gaussian Process Regression models based on GPyTorch models.
7+
"""
8+
9+
from typing import Optional
10+
11+
import torch
12+
from botorch.exceptions import UnsupportedError
13+
from botorch.models.fidelity_kernels.downsampling_kernel import DownsamplingKernel
14+
from botorch.models.fidelity_kernels.exponential_decay_kernel import ExpDecayKernel
15+
from gpytorch.kernels.rbf_kernel import RBFKernel
16+
from gpytorch.kernels.scale_kernel import ScaleKernel
17+
from gpytorch.likelihoods.likelihood import Likelihood
18+
from gpytorch.priors.torch_priors import GammaPrior
19+
from torch import Tensor
20+
21+
from ..gp_regression import SingleTaskGP
22+
23+
24+
class SingleTaskMultiFidelityGP(SingleTaskGP):
25+
r"""A single task multi-fidelity GP model.
26+
27+
A sub-class of SingleTaskGP model. By default the last two dimensions of train_X
28+
are the fidelity parameters: training iterations, training data points.
29+
The kernel comes from this paper `https://arxiv.org/abs/1903.04703`
30+
31+
Args:
32+
train_X: A `n x (d + s)` or `batch_shape x n x (d + s) ` (batch mode) tensor
33+
of training features, s is the dimension of the fidelity parameters.
34+
train_Y: A `n x (o)` or `batch_shape x n x (o)` (batch mode) tensor of
35+
training observations.
36+
train_iteration_fidelity: An indicator of whether we have the training
37+
iteration fidelity variable.
38+
train_data_fidelity: An indicator of whether we have the downsampling
39+
fidelity variable. If train_iteration_fidelity and train_data_fidelity
40+
are both True, the last and second last columns are treated as the
41+
training data points fidelity parameter and training iteration
42+
number fidelity parameter respectively. Otherwise the last column of
43+
train_X is treated as the fidelity parameter with True indicator.
44+
We assume train_X has at least one fidelity parameter.
45+
likelihood: A likelihood. If omitted, use a standard
46+
GaussianLikelihood with inferred noise level.
47+
48+
Example:
49+
>>> train_X = torch.rand(20, 4)
50+
>>> train_Y = train_X.pow(2).sum(dim=-1)
51+
>>> model = SingleTaskMultiFidelityGP(train_X, train_Y)
52+
"""
53+
54+
def __init__(
55+
self,
56+
train_X: Tensor,
57+
train_Y: Tensor,
58+
train_iteration_fidelity: bool = True,
59+
train_data_fidelity: bool = True,
60+
likelihood: Optional[Likelihood] = None,
61+
) -> None:
62+
train_X, train_Y, _ = self._set_dimensions(train_X=train_X, train_Y=train_Y)
63+
num_fidelity = train_iteration_fidelity + train_data_fidelity
64+
ard_num_dims = train_X.shape[-1] - num_fidelity
65+
active_dimsX = list(range(train_X.shape[-1] - num_fidelity))
66+
rbf_kernel = RBFKernel(
67+
ard_num_dims=ard_num_dims,
68+
batch_shape=self._aug_batch_shape,
69+
lengthscale_prior=GammaPrior(3.0, 6.0),
70+
active_dims=active_dimsX,
71+
)
72+
exp_kernel = ExpDecayKernel(
73+
batch_shape=self._aug_batch_shape,
74+
lengthscale_prior=GammaPrior(3.0, 6.0),
75+
offset_prior=GammaPrior(3.0, 6.0),
76+
power_prior=GammaPrior(3.0, 6.0),
77+
)
78+
ds_kernel = DownsamplingKernel(
79+
batch_shape=self._aug_batch_shape,
80+
offset_prior=GammaPrior(3.0, 6.0),
81+
power_prior=GammaPrior(3.0, 6.0),
82+
)
83+
if train_iteration_fidelity and train_data_fidelity:
84+
active_dimsS1 = [train_X.shape[-1] - 1]
85+
active_dimsS2 = [train_X.shape[-1] - 2]
86+
exp_kernel.active_dims = torch.tensor(active_dimsS1)
87+
ds_kernel.active_dims = torch.tensor(active_dimsS2)
88+
kernel = rbf_kernel * exp_kernel * ds_kernel
89+
elif train_iteration_fidelity or train_data_fidelity:
90+
active_dimsS = [train_X.shape[-1] - 1]
91+
if train_iteration_fidelity:
92+
exp_kernel.active_dims = torch.tensor(active_dimsS)
93+
kernel = rbf_kernel * exp_kernel
94+
else:
95+
ds_kernel.active_dims = torch.tensor(active_dimsS)
96+
kernel = rbf_kernel * ds_kernel
97+
else:
98+
raise UnsupportedError("You should have at least one fidelity parameter.")
99+
covar_module = ScaleKernel(
100+
kernel,
101+
batch_shape=self._aug_batch_shape,
102+
outputscale_prior=GammaPrior(2.0, 0.15),
103+
)
104+
super().__init__(train_X=train_X, train_Y=train_Y, covar_module=covar_module)
105+
self.to(train_X)

botorch/models/gp_regression.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise
2323
from gpytorch.means.constant_mean import ConstantMean
2424
from gpytorch.models.exact_gp import ExactGP
25+
from gpytorch.module import Module
2526
from gpytorch.priors.smoothed_box_prior import SmoothedBoxPrior
2627
from gpytorch.priors.torch_priors import GammaPrior
2728
from torch import Tensor
@@ -52,7 +53,11 @@ class SingleTaskGP(BatchedMultiOutputGPyTorchModel, ExactGP):
5253
"""
5354

5455
def __init__(
55-
self, train_X: Tensor, train_Y: Tensor, likelihood: Optional[Likelihood] = None
56+
self,
57+
train_X: Tensor,
58+
train_Y: Tensor,
59+
likelihood: Optional[Likelihood] = None,
60+
covar_module: Optional[Module] = None,
5661
) -> None:
5762
r"""A single-task exact GP model.
5863
@@ -63,13 +68,14 @@ def __init__(
6368
training observations.
6469
likelihood: A likelihood. If omitted, use a standard
6570
GaussianLikelihood with inferred noise level.
71+
covar_module: The covariance (kernel) matrix. If omitted, use the
72+
MaternKernel.
6673
6774
Example:
6875
>>> train_X = torch.rand(20, 2)
6976
>>> train_Y = torch.sin(train_X[:, 0]) + torch.cos(train_X[:, 1])
7077
>>> model = SingleTaskGP(train_X, train_Y)
7178
"""
72-
ard_num_dims = train_X.shape[-1]
7379
train_X, train_Y, _ = self._set_dimensions(train_X=train_X, train_Y=train_Y)
7480
train_X, train_Y, _ = multioutput_to_batch_mode_transform(
7581
train_X=train_X, train_Y=train_Y, num_outputs=self._num_outputs
@@ -90,16 +96,19 @@ def __init__(
9096
self._is_custom_likelihood = True
9197
ExactGP.__init__(self, train_X, train_Y, likelihood)
9298
self.mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
93-
self.covar_module = ScaleKernel(
94-
MaternKernel(
95-
nu=2.5,
96-
ard_num_dims=ard_num_dims,
99+
if covar_module is None:
100+
self.covar_module = ScaleKernel(
101+
MaternKernel(
102+
nu=2.5,
103+
ard_num_dims=train_X.shape[-1],
104+
batch_shape=self._aug_batch_shape,
105+
lengthscale_prior=GammaPrior(3.0, 6.0),
106+
),
97107
batch_shape=self._aug_batch_shape,
98-
lengthscale_prior=GammaPrior(3.0, 6.0),
99-
),
100-
batch_shape=self._aug_batch_shape,
101-
outputscale_prior=GammaPrior(2.0, 0.15),
102-
)
108+
outputscale_prior=GammaPrior(2.0, 0.15),
109+
)
110+
else:
111+
self.covar_module = covar_module
103112
self.to(train_X)
104113

105114
def forward(self, x: Tensor) -> MultivariateNormal:
@@ -136,7 +145,6 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
136145
>>> train_Yvar = torch.full_like(train_Y, 0.2)
137146
>>> model = FixedNoiseGP(train_X, train_Y, train_Yvar)
138147
"""
139-
ard_num_dims = train_X.shape[-1]
140148
train_X, train_Y, train_Yvar = self._set_dimensions(
141149
train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar
142150
)
@@ -156,7 +164,7 @@ def __init__(self, train_X: Tensor, train_Y: Tensor, train_Yvar: Tensor) -> None
156164
self.covar_module = ScaleKernel(
157165
base_kernel=MaternKernel(
158166
nu=2.5,
159-
ard_num_dims=ard_num_dims,
167+
ard_num_dims=train_X.shape[-1],
160168
batch_shape=self._aug_batch_shape,
161169
lengthscale_prior=GammaPrior(3.0, 6.0),
162170
),

test/models/fidelity/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#! /usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

0 commit comments

Comments
 (0)