
Commit a36c2e6

Daniel Jiang authored and facebook-github-bot committed
FixedNoiseMultiFidelityGP (#386)
Summary: Pull Request resolved: #386

Adds a FixedNoiseMultiFidelityGP + unit tests. Changes FixedNoiseGP to allow a covar_module to be passed in.

Reviewed By: Balandat

Differential Revision: D20235817

fbshipit-source-id: 13029b765c5fe41136dae1f7ccdd44f5d50606af
1 parent a6eddcb · commit a36c2e6

File tree

3 files changed: +332 -130 lines changed

botorch/models/gp_regression.py

Lines changed: 20 additions & 14 deletions
@@ -154,6 +154,7 @@ def __init__(
         train_X: Tensor,
         train_Y: Tensor,
         train_Yvar: Tensor,
+        covar_module: Optional[Module] = None,
         outcome_transform: Optional[OutcomeTransform] = None,
     ) -> None:
         r"""A single-task exact GP model using fixed noise levels.
@@ -189,23 +190,28 @@ def __init__(
             self, train_inputs=train_X, train_targets=train_Y, likelihood=likelihood
         )
         self.mean_module = ConstantMean(batch_shape=self._aug_batch_shape)
-        self.covar_module = ScaleKernel(
-            base_kernel=MaternKernel(
-                nu=2.5,
-                ard_num_dims=train_X.shape[-1],
+        if covar_module is None:
+            self.covar_module = ScaleKernel(
+                base_kernel=MaternKernel(
+                    nu=2.5,
+                    ard_num_dims=train_X.shape[-1],
+                    batch_shape=self._aug_batch_shape,
+                    lengthscale_prior=GammaPrior(3.0, 6.0),
+                ),
                 batch_shape=self._aug_batch_shape,
-                lengthscale_prior=GammaPrior(3.0, 6.0),
-            ),
-            batch_shape=self._aug_batch_shape,
-            outputscale_prior=GammaPrior(2.0, 0.15),
-        )
+                outputscale_prior=GammaPrior(2.0, 0.15),
+            )
+            self._subset_batch_dict = {
+                "mean_module.constant": -2,
+                "covar_module.raw_outputscale": -1,
+                "covar_module.base_kernel.raw_lengthscale": -3,
+            }
+        else:
+            self.covar_module = covar_module
+            # TODO: Allow subsetting of other covar modules
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
-        self._subset_batch_dict = {
-            "mean_module.constant": -2,
-            "covar_module.raw_outputscale": -1,
-            "covar_module.base_kernel.raw_lengthscale": -3,
-        }
+
         self.to(train_X)
 
     def fantasize(

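Not part of the diff: a minimal sketch of what the new covar_module argument to FixedNoiseGP enables. The RBF kernel choice and the toy data here are illustrative assumptions, not code from this commit.

    import torch
    from botorch.models.gp_regression import FixedNoiseGP
    from gpytorch.kernels import RBFKernel, ScaleKernel

    # Toy fixed-noise regression data.
    train_X = torch.rand(20, 3)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    train_Yvar = torch.full_like(train_Y, 0.01)  # known observation noise

    # With covar_module=None (the default), FixedNoiseGP builds its usual
    # Matern-5/2 ScaleKernel; any GPyTorch kernel can now be passed instead.
    custom_kernel = ScaleKernel(RBFKernel(ard_num_dims=train_X.shape[-1]))
    model = FixedNoiseGP(train_X, train_Y, train_Yvar, covar_module=custom_kernel)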
botorch/models/gp_regression_fidelity.py

Lines changed: 200 additions & 77 deletions
@@ -14,8 +14,9 @@
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import Dict, Optional, Tuple
 
+import torch
 from gpytorch.kernels.kernel import ProductKernel
 from gpytorch.kernels.rbf_kernel import RBFKernel
 from gpytorch.kernels.scale_kernel import ScaleKernel
@@ -24,7 +25,7 @@
 from torch import Tensor
 
 from ..exceptions.errors import UnsupportedError
-from .gp_regression import SingleTaskGP
+from .gp_regression import FixedNoiseGP, SingleTaskGP
 from .kernels.downsampling import DownsamplingKernel
 from .kernels.exponential_decay import ExponentialDecayKernel
 from .kernels.linear_truncated_fidelity import LinearTruncatedFidelityKernel
@@ -55,11 +56,15 @@ class SingleTaskMultiFidelityGP(SingleTaskGP):
             5/2. Only used when `linear_truncated=True`.
         likelihood: A likelihood. If omitted, use a standard GaussianLikelihood
             with inferred noise level.
+        outcome_transform: An outcome transform that is applied to the
+            training data during instantiation and to the posterior during
+            inference (that is, the `Posterior` obtained by calling
+            `.posterior` on the model will be on the original scale).
 
     Example:
         >>> train_X = torch.rand(20, 4)
         >>> train_Y = train_X.pow(2).sum(dim=-1, keepdim=True)
-        >>> model = SingleTaskMultiFidelityGP(train_X, train_Y)
+        >>> model = SingleTaskMultiFidelityGP(train_X, train_Y, data_fidelity=3)
     """
 
     def __init__(
@@ -84,93 +89,211 @@ def __init__(
             raise UnsupportedError(
                 "SingleTaskMultiFidelityGP requires at least one fidelity parameter."
             )
-        if iteration_fidelity is not None and iteration_fidelity < 0:
-            iteration_fidelity = train_X.size(-1) + iteration_fidelity
-        if data_fidelity is not None and data_fidelity < 0:
-            data_fidelity = train_X.size(-1) + data_fidelity
         self._set_dimensions(train_X=train_X, train_Y=train_Y)
-        if linear_truncated:
-            fidelity_dims = [
-                i for i in (iteration_fidelity, data_fidelity) if i is not None
-            ]
-            kernel = LinearTruncatedFidelityKernel(
-                fidelity_dims=fidelity_dims,
-                dimension=train_X.size(-1),
-                nu=nu,
-                batch_shape=self._aug_batch_shape,
-                power_prior=GammaPrior(3.0, 3.0),
-            )
-        else:
-            active_dimsX = [
-                i
-                for i in range(train_X.size(-1))
-                if i not in {iteration_fidelity, data_fidelity}
-            ]
-            kernel = RBFKernel(
-                ard_num_dims=len(active_dimsX),
-                batch_shape=self._aug_batch_shape,
-                lengthscale_prior=GammaPrior(3.0, 6.0),
-                active_dims=active_dimsX,
-            )
-            additional_kernels = []
-            if iteration_fidelity is not None:
-                exp_kernel = ExponentialDecayKernel(
-                    batch_shape=self._aug_batch_shape,
-                    lengthscale_prior=GammaPrior(3.0, 6.0),
-                    offset_prior=GammaPrior(3.0, 6.0),
-                    power_prior=GammaPrior(3.0, 6.0),
-                    active_dims=[iteration_fidelity],
-                )
-                additional_kernels.append(exp_kernel)
-            if data_fidelity is not None:
-                ds_kernel = DownsamplingKernel(
-                    batch_shape=self._aug_batch_shape,
-                    offset_prior=GammaPrior(3.0, 6.0),
-                    power_prior=GammaPrior(3.0, 6.0),
-                    active_dims=[data_fidelity],
-                )
-                additional_kernels.append(ds_kernel)
-            kernel = ProductKernel(kernel, *additional_kernels)
-
-        covar_module = ScaleKernel(
-            kernel,
-            batch_shape=self._aug_batch_shape,
-            outputscale_prior=GammaPrior(2.0, 0.15),
+        covar_module, subset_batch_dict = _setup_multifidelity_covar_module(
+            dim=train_X.size(-1),
+            aug_batch_shape=self._aug_batch_shape,
+            iteration_fidelity=iteration_fidelity,
+            data_fidelity=data_fidelity,
+            linear_truncated=linear_truncated,
+            nu=nu,
         )
         super().__init__(
             train_X=train_X,
             train_Y=train_Y,
+            likelihood=likelihood,
             covar_module=covar_module,
             outcome_transform=outcome_transform,
         )
-        if linear_truncated:
-            subset_batch_dict = {
-                "covar_module.base_kernel.raw_power": -2,
-                "covar_module.base_kernel.covar_module_unbiased.raw_lengthscale": -3,
-                "covar_module.base_kernel.covar_module_biased.raw_lengthscale": -3,
-            }
-        else:
-            subset_batch_dict = {
-                "covar_module.base_kernel.kernels.0.raw_lengthscale": -3,
-                "covar_module.base_kernel.kernels.1.raw_power": -2,
-                "covar_module.base_kernel.kernels.1.raw_offset": -2,
-            }
-            if iteration_fidelity is not None:
-                subset_batch_dict = {
-                    "covar_module.base_kernel.kernels.1.raw_lengthscale": -3,
-                    **subset_batch_dict,
-                }
-            if data_fidelity is not None:
-                subset_batch_dict = {
-                    "covar_module.base_kernel.kernels.2.raw_power": -2,
-                    "covar_module.base_kernel.kernels.2.raw_offset": -2,
-                    **subset_batch_dict,
-                }
         self._subset_batch_dict = {
             "likelihood.noise_covar.raw_noise": -2,
             "mean_module.constant": -2,
             "covar_module.raw_outputscale": -1,
             **subset_batch_dict,
         }
+        self.to(train_X)
+
+
+class FixedNoiseMultiFidelityGP(FixedNoiseGP):
+    r"""A single-task multi-fidelity GP model using fixed noise levels.
+
+    A FixedNoiseGP model analogue of SingleTaskMultiFidelityGP, using a
+    DownsamplingKernel for the data fidelity parameter (if present) and
+    an ExponentialDecayKernel for the iteration fidelity parameter (if present).
+
+    This kernel is described in [Wu2019mf]_.
+
+    Args:
+        train_X: A `batch_shape x n x (d + s)` tensor of training features,
+            where `s` is the dimension of the fidelity parameters (either one
+            or two).
+        train_Y: A `batch_shape x n x m` tensor of training observations.
+        train_Yvar: A `batch_shape x n x m` tensor of observed measurement noise.
+        iteration_fidelity: The column index for the training iteration fidelity
+            parameter (optional).
+        data_fidelity: The column index for the downsampling fidelity parameter
+            (optional).
+        linear_truncated: If True, use a `LinearTruncatedFidelityKernel` instead
+            of the default kernel.
+        nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2, or
+            5/2. Only used when `linear_truncated=True`.
+        outcome_transform: An outcome transform that is applied to the
+            training data during instantiation and to the posterior during
+            inference (that is, the `Posterior` obtained by calling
+            `.posterior` on the model will be on the original scale).
+
+    Example:
+        >>> train_X = torch.rand(20, 4)
+        >>> train_Y = train_X.pow(2).sum(dim=-1, keepdim=True)
+        >>> train_Yvar = torch.full_like(train_Y, 0.01)
+        >>> model = FixedNoiseMultiFidelityGP(
+        >>>     train_X,
+        >>>     train_Y,
+        >>>     train_Yvar,
+        >>>     data_fidelity=3,
+        >>> )
+    """
 
+    def __init__(
+        self,
+        train_X: Tensor,
+        train_Y: Tensor,
+        train_Yvar: Tensor,
+        iteration_fidelity: Optional[int] = None,
+        data_fidelity: Optional[int] = None,
+        linear_truncated: bool = True,
+        nu: float = 2.5,
+        outcome_transform: Optional[OutcomeTransform] = None,
+    ) -> None:
+        if iteration_fidelity is None and data_fidelity is None:
+            raise UnsupportedError(
+                "FixedNoiseMultiFidelityGP requires at least one fidelity parameter."
+            )
+        self._set_dimensions(train_X=train_X, train_Y=train_Y)
+        covar_module, subset_batch_dict = _setup_multifidelity_covar_module(
+            dim=train_X.size(-1),
+            aug_batch_shape=self._aug_batch_shape,
+            iteration_fidelity=iteration_fidelity,
+            data_fidelity=data_fidelity,
+            linear_truncated=linear_truncated,
+            nu=nu,
+        )
+        super().__init__(
+            train_X=train_X,
+            train_Y=train_Y,
+            train_Yvar=train_Yvar,
+            covar_module=covar_module,
+            outcome_transform=outcome_transform,
+        )
+        self._subset_batch_dict = {
+            "likelihood.noise_covar.raw_noise": -2,
+            "mean_module.constant": -2,
+            "covar_module.raw_outputscale": -1,
+            **subset_batch_dict,
+        }
         self.to(train_X)
+
+
+def _setup_multifidelity_covar_module(
+    dim: int,
+    aug_batch_shape: torch.Size,
+    iteration_fidelity: Optional[int],
+    data_fidelity: Optional[int],
+    linear_truncated: bool,
+    nu: float,
+) -> Tuple[ScaleKernel, Dict]:
+    """Helper function to get the covariance module and associated subset_batch_dict
+    for the multifidelity setting.
+
+    Args:
+        dim: The dimensionality of the training data.
+        aug_batch_shape: The output-augmented batch shape as defined in
+            `BatchedMultiOutputGPyTorchModel`.
+        iteration_fidelity: The column index for the training iteration fidelity
+            parameter (optional).
+        data_fidelity: The column index for the downsampling fidelity parameter
+            (optional).
+        linear_truncated: If True, use a `LinearTruncatedFidelityKernel` instead
+            of the default kernel.
+        nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2, or
+            5/2. Only used when `linear_truncated=True`.
+
+    Returns:
+        The covariance module and subset_batch_dict.
+    """
+
+    if iteration_fidelity is not None and iteration_fidelity < 0:
+        iteration_fidelity = dim + iteration_fidelity
+    if data_fidelity is not None and data_fidelity < 0:
+        data_fidelity = dim + data_fidelity
+
+    if linear_truncated:
+        fidelity_dims = [
+            i for i in (iteration_fidelity, data_fidelity) if i is not None
+        ]
+        kernel = LinearTruncatedFidelityKernel(
+            fidelity_dims=fidelity_dims,
+            dimension=dim,
+            nu=nu,
+            batch_shape=aug_batch_shape,
+            power_prior=GammaPrior(3.0, 3.0),
+        )
+    else:
+        active_dimsX = [
+            i for i in range(dim) if i not in {iteration_fidelity, data_fidelity}
+        ]
+        kernel = RBFKernel(
+            ard_num_dims=len(active_dimsX),
+            batch_shape=aug_batch_shape,
+            lengthscale_prior=GammaPrior(3.0, 6.0),
+            active_dims=active_dimsX,
+        )
+        additional_kernels = []
+        if iteration_fidelity is not None:
+            exp_kernel = ExponentialDecayKernel(
+                batch_shape=aug_batch_shape,
+                lengthscale_prior=GammaPrior(3.0, 6.0),
+                offset_prior=GammaPrior(3.0, 6.0),
+                power_prior=GammaPrior(3.0, 6.0),
+                active_dims=[iteration_fidelity],
+            )
+            additional_kernels.append(exp_kernel)
+        if data_fidelity is not None:
+            ds_kernel = DownsamplingKernel(
+                batch_shape=aug_batch_shape,
+                offset_prior=GammaPrior(3.0, 6.0),
+                power_prior=GammaPrior(3.0, 6.0),
+                active_dims=[data_fidelity],
+            )
+            additional_kernels.append(ds_kernel)
+        kernel = ProductKernel(kernel, *additional_kernels)
+
+    covar_module = ScaleKernel(
+        kernel, batch_shape=aug_batch_shape, outputscale_prior=GammaPrior(2.0, 0.15)
+    )
+
+    if linear_truncated:
+        subset_batch_dict = {
+            "covar_module.base_kernel.raw_power": -2,
+            "covar_module.base_kernel.covar_module_unbiased.raw_lengthscale": -3,
+            "covar_module.base_kernel.covar_module_biased.raw_lengthscale": -3,
+        }
+    else:
+        subset_batch_dict = {
+            "covar_module.base_kernel.kernels.0.raw_lengthscale": -3,
+            "covar_module.base_kernel.kernels.1.raw_power": -2,
+            "covar_module.base_kernel.kernels.1.raw_offset": -2,
+        }
+        if iteration_fidelity is not None:
+            subset_batch_dict = {
+                "covar_module.base_kernel.kernels.1.raw_lengthscale": -3,
+                **subset_batch_dict,
+            }
+        if data_fidelity is not None:
+            subset_batch_dict = {
+                "covar_module.base_kernel.kernels.2.raw_power": -2,
+                "covar_module.base_kernel.kernels.2.raw_offset": -2,
+                **subset_batch_dict,
+            }
+
+    return covar_module, subset_batch_dict
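Not part of the diff: a minimal end-to-end sketch of constructing and fitting the new FixedNoiseMultiFidelityGP, assuming the standard BoTorch fitting utilities (fit_gpytorch_model with an ExactMarginalLogLikelihood).

    import torch
    from botorch.fit import fit_gpytorch_model
    from botorch.models.gp_regression_fidelity import FixedNoiseMultiFidelityGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # Toy data: column 3 (the last column) is the data fidelity parameter.
    train_X = torch.rand(20, 4)
    train_Y = train_X.pow(2).sum(dim=-1, keepdim=True)
    train_Yvar = torch.full_like(train_Y, 0.01)  # known observation noise

    model = FixedNoiseMultiFidelityGP(train_X, train_Y, train_Yvar, data_fidelity=3)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_model(mll)  # fits kernel hyperparameters; the noise is fixed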

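The summary also mentions unit tests, but the test file's diff is not rendered in this view. A rough sketch of the kind of test that would exercise the new model (the names, shapes, and structure here are assumptions, not the committed tests):

    import unittest

    import torch
    from botorch.models.gp_regression_fidelity import FixedNoiseMultiFidelityGP

    class TestFixedNoiseMultiFidelityGP(unittest.TestCase):
        def test_posterior_shape(self):
            train_X = torch.rand(20, 4)
            train_Y = train_X.pow(2).sum(dim=-1, keepdim=True)
            train_Yvar = torch.full_like(train_Y, 0.01)
            model = FixedNoiseMultiFidelityGP(
                train_X, train_Y, train_Yvar, data_fidelity=3
            )
            # Posterior mean for 5 test points and one output should be 5 x 1.
            posterior = model.posterior(torch.rand(5, 4))
            self.assertEqual(posterior.mean.shape, torch.Size([5, 1]))

    if __name__ == "__main__":
        unittest.main()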