Commit 280b8bd

VilockLi authored and facebook-github-bot committed
Add linear truncated kernel for multi-fidelity Bayesian optimization (#192)
Summary: Pull Request resolved: #192

We add a linear truncated kernel with a polynomial basis, and extend it to handle two fidelity parameters by adding an interaction term.

Reviewed By: Balandat

Differential Revision: D15768548

fbshipit-source-id: 093c4f4c2cdb52466117660c0bdc04286cecf9b8
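
For reference, the kernel added in this diff (as defined in the new module's docstring below) combines a base Matern kernel with three fidelity-modulated Matern terms:

    k_{\text{LinearTruncated}}(\mathbf{x_1}, \mathbf{x_2}) = k_0 + c_1 k_1 + c_2 k_2 + c_3 k_3

where c_1 and c_3 are bias terms each built from one fidelity parameter, c_2 is the interaction term over both fidelity parameters, and each c_i multiplies deterministic factors with a polynomial kernel of order p on the fidelity entries.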
1 parent: b77590c · commit: 280b8bd

File tree: 2 files changed, +375 −0 lines
Lines changed: 226 additions & 0 deletions (botorch/models/fidelity_kernels/linear_truncated_fidelity.py, per the import path used in the tests below)
@@ -0,0 +1,226 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Optional

import torch
from botorch.exceptions import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors import Prior
from gpytorch.priors.torch_priors import GammaPrior


class LinearTruncatedFidelityKernel(Kernel):
    r"""
    Computes a covariance matrix based on the linear truncated kernel
    between inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::

        \begin{equation*}
            k_{\text{LinearTruncated}}(\mathbf{x_1}, \mathbf{x_2}) = k_0
            + c_1(\mathbf{x_1}, \mathbf{x_1}) k_1 + c_2(\mathbf{x_1}, \mathbf{x_2}) k_2
            + c_3(\mathbf{x_2}, \mathbf{x_2}) k_3
        \end{equation*}

    where

    * :math:`k_i \ (i=0,1,2,3)` are Matern kernels calculated between
      :math:`\mathbf{x_1}[:-2]` and :math:`\mathbf{x_2}[:-2]` with different priors.
    * :math:`c_1 = (1 - \mathbf{x_1}[-1])(1 - \mathbf{x_2}[-1])
      (1 + \mathbf{x_1}[-1]\mathbf{x_2}[-1])^p` is the bias kernel, which can be
      decomposed into a deterministic part and a polynomial kernel.
    * :math:`c_3` is the same as :math:`c_1`, but is calculated from the
      second-to-last entries of :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.
    * :math:`c_2` is the interaction term, combining four deterministic factors
      with the polynomial kernel between :math:`\mathbf{x_1}[-2:]` and
      :math:`\mathbf{x_2}[-2:]`:
      :math:`c_2 = (1 - \mathbf{x_1}[-2])(1 - \mathbf{x_2}[-2])
      (1 - \mathbf{x_1}[-1])(1 - \mathbf{x_2}[-1])
      (1 + \mathbf{x_1}[-2:] \cdot \mathbf{x_2}[-2:])^p`.
    * :math:`p` is the order of the polynomial kernel.

    .. note::

        We assume the last two dimensions of the input `x` are the fidelity
        parameters.

    Args:
        :attr:`nu` (float):
            The smoothness parameter of the Matern kernel: either 1/2, 3/2, or
            5/2. Default: `2.5`
        :attr:`batch_shape` (torch.Size, optional):
            Set this if you want a separate lengthscale for each batch of input
            data. It should be `b` if :attr:`x1` is a `b x n x d` tensor.
            Default: `torch.Size([])`
        :attr:`active_dims` (tuple of ints, optional):
            Set this if you want to compute the covariance of only a few input
            dimensions. The ints correspond to the indices of the dimensions.
            Default: `None`.
        :attr:`lengthscale_prior` (Prior, optional):
            Set this if you want to apply a prior to the lengthscale parameter
            of the Matern kernel `k_0`. Default: `Gamma(1.1, 1/20)`
        :attr:`lengthscale_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the lengthscale
            parameter of the Matern kernel `k_0`. Default: `Positive`
        :attr:`lengthscale_2_prior` (Prior, optional):
            Set this if you want to apply a prior to the lengthscale parameter
            of the Matern kernels `k_i (i > 0)`. Default: `Gamma(5, 1/20)`
        :attr:`lengthscale_2_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the lengthscale
            parameter of the Matern kernels `k_i (i > 0)`. Default: `Positive`
        :attr:`power_prior` (Prior, optional):
            Set this if you want to apply a prior to the power parameter of the
            polynomial kernel. Default: `None`
        :attr:`power_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the power parameter of
            the polynomial kernel. Default: `Positive`

    Attributes:
        :attr:`lengthscale` (Tensor):
            The lengthscale parameter.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel()
        >>> covar = covar_module(x)  # Output: LazyVariable of size (10 x 10)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel(batch_shape=torch.Size([2]))
        >>> covar = covar_module(batch_x)  # Output: LazyVariable of size (2 x 10 x 10)
    """

    def __init__(
        self,
        nu: float = 2.5,
        train_iteration_fidelity: bool = True,
        train_data_fidelity: bool = True,
        lengthscale_prior: Optional[Prior] = None,
        power_prior: Optional[Prior] = None,
        power_constraint: Optional[Interval] = None,
        lengthscale_2_prior: Optional[Prior] = None,
        lengthscale_2_constraint: Optional[Interval] = None,
        **kwargs: Any,
    ):
        if not train_iteration_fidelity and not train_data_fidelity:
            raise UnsupportedError("You should have at least one fidelity parameter.")
        if nu not in {0.5, 1.5, 2.5}:
            raise ValueError("nu expected to be 0.5, 1.5, or 2.5")
        super().__init__(has_lengthscale=True, **kwargs)
        self.train_iteration_fidelity = train_iteration_fidelity
        self.train_data_fidelity = train_data_fidelity
        if power_constraint is None:
            power_constraint = Positive()

        if lengthscale_prior is None:
            self.lengthscale_prior = GammaPrior(1.1, 1 / 20)
        else:
            self.lengthscale_prior = lengthscale_prior

        if lengthscale_2_prior is None:
            self.lengthscale_2_prior = GammaPrior(5, 1 / 20)
        else:
            self.register_prior(
                "lengthscale_2_prior",
                lengthscale_2_prior,
                lambda: self.lengthscale_2,
                lambda v: self._set_lengthscale_2(v),
            )

        if lengthscale_2_constraint is None:
            lengthscale_2_constraint = Positive()

        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        self.register_parameter(
            name="raw_lengthscale_2",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if power_prior is not None:
            self.register_prior(
                "power_prior",
                power_prior,
                lambda: self.power,
                lambda v: self._set_power(v),
            )
        self.nu = nu
        self.register_constraint("raw_lengthscale_2", lengthscale_2_constraint)
        self.register_constraint("raw_power", power_constraint)

    @property
    def power(self) -> torch.Tensor:
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: torch.Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: torch.Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    @property
    def lengthscale_2(self) -> torch.Tensor:
        return self.raw_lengthscale_2_constraint.transform(self.raw_lengthscale_2)

    @lengthscale_2.setter
    def lengthscale_2(self, value: torch.Tensor) -> None:
        self._set_lengthscale_2(value)

    def _set_lengthscale_2(self, value: torch.Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_lengthscale_2)
        self.initialize(
            raw_lengthscale_2=self.raw_lengthscale_2_constraint.inverse_transform(value)
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, **params) -> torch.Tensor:
        # Number of fidelity parameters; they occupy the last m input columns.
        m = self.train_iteration_fidelity + self.train_data_fidelity
        power = self.power.view(*self.batch_shape, 1, 1)
        # The Matern kernels act only on the non-fidelity dimensions.
        active_dimsM = list(range(x1.size()[-1] - m))
        covar_module_1 = MaternKernel(
            nu=self.nu,
            batch_shape=self.batch_shape,
            lengthscale_prior=self.lengthscale_prior,
            active_dims=active_dimsM,
            ard_num_dims=x1.shape[-1] - m,
        )
        covar_module_2 = MaternKernel(
            nu=self.nu,
            batch_shape=self.batch_shape,
            lengthscale_prior=self.lengthscale_2_prior,
            active_dims=active_dimsM,
            ard_num_dims=x1.shape[-1] - m,
        )
        covar_0 = covar_module_1(x1, x2)
        # Last fidelity column of x1 (as a column) and of x2 (as a row).
        x11_ = x1[..., -1].unsqueeze(-1)
        x21t_ = x2[..., -1].unsqueeze(-1).transpose(-1, -2)
        covar_1 = covar_module_2(x1, x2)
        if self.train_iteration_fidelity and self.train_data_fidelity:
            covar_2 = covar_module_2(x1, x2)
            covar_3 = covar_module_2(x1, x2)
            # Second-to-last fidelity column of x1 and of x2.
            x12_ = x1[..., -2].unsqueeze(-1)
            x22t_ = x2[..., -2].unsqueeze(-1).transpose(-1, -2)
            # k_0 + c_1 k_1 + c_2 k_2 + c_3 k_3 (see the class docstring).
            res = (
                covar_0
                + (1 - x12_) * (1 - x22t_) * (1 + x12_ * x22t_).pow(power) * covar_1
                + (1 - x12_)
                * (1 - x22t_)
                * (1 - x11_)
                * (1 - x21t_)
                * (1 + torch.matmul(x1[..., -2:], x2[..., -2:].transpose(-1, -2))).pow(
                    power
                )
                * covar_2
                + (1 - x11_) * (1 - x21t_) * (1 + x11_ * x21t_).pow(power) * covar_3
            )
        else:
            # Single fidelity parameter: k_0 + c_1 k_1.
            res = (
                covar_0
                + (1 - x11_) * (1 - x21t_) * (1 + x11_ * x21t_).pow(power) * covar_1
            )
        return res
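
For illustration, here is a minimal usage sketch of the new kernel (not part of this diff), mirroring the docstring example; it assumes inputs carry their two fidelity parameters as the last two columns, with values in [0, 1]:

import torch
from botorch.models.fidelity_kernels.linear_truncated_fidelity import (
    LinearTruncatedFidelityKernel,
)

# 3 design dimensions + 2 fidelity parameters in the last two columns.
x = torch.rand(10, 5)
covar_module = LinearTruncatedFidelityKernel()
covar = covar_module(x)  # lazy kernel tensor of size 10 x 10
print(covar.evaluate().shape)  # torch.Size([10, 10])

# With a single fidelity parameter, only the last column is treated as a fidelity.
covar_module_single = LinearTruncatedFidelityKernel(train_data_fidelity=False)
print(covar_module_single(x).evaluate().shape)  # torch.Size([10, 10])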
Lines changed: 149 additions & 0 deletions (unit tests for LinearTruncatedFidelityKernel)
@@ -0,0 +1,149 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch
from botorch.exceptions import UnsupportedError
from botorch.models.fidelity_kernels.linear_truncated_fidelity import (
    LinearTruncatedFidelityKernel,
)
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors.torch_priors import GammaPrior, NormalPrior
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestLinearTruncatedFidelityKernel(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return LinearTruncatedFidelityKernel(**kwargs)

    def create_data_no_batch(self):
        return torch.rand(50, 10)

    def create_data_single_batch(self):
        return torch.rand(2, 50, 3)

    def create_data_double_batch(self):
        return torch.rand(3, 2, 50, 3)

    def test_compute_linear_truncated_kernel_no_batch(self):
        x1 = torch.tensor([1, 0.1, 0.2, 2, 0.3, 0.4], dtype=torch.float).view(2, 3)
        x2 = torch.tensor([3, 0.5, 0.6, 4, 0.7, 0.8], dtype=torch.float).view(2, 3)
        # Precomputed deterministic fidelity prefactors for power=1: t_1 from
        # the last fidelity column, t_2 from the second-to-last, and t_3 from
        # the interaction term.
        t_1 = torch.tensor([0.3584, 0.1856, 0.2976, 0.1584], dtype=torch.float).view(
            2, 2
        )
        for nu in {0.5, 1.5, 2.5}:
            for train_data_fidelity in {False, True}:
                kernel = LinearTruncatedFidelityKernel(nu=nu)
                kernel.power = 1
                kernel.train_data_fidelity = train_data_fidelity
                if train_data_fidelity:
                    active_dimsM = [0]
                    t_2 = torch.tensor(
                        [0.4725, 0.2889, 0.4025, 0.2541], dtype=torch.float
                    ).view(2, 2)
                    t_3 = torch.tensor(
                        [0.1685, 0.0531, 0.1168, 0.0386], dtype=torch.float
                    ).view(2, 2)
                    t = 1 + t_1 + t_2 + t_3
                else:
                    active_dimsM = [0, 1]
                    t = 1 + t_1
                matern_ker = MaternKernel(nu=nu, active_dims=active_dimsM)
                matern_term = matern_ker(x1, x2).evaluate()
                actual = t * matern_term
                res = kernel(x1, x2).evaluate()
                self.assertLess(torch.norm(res - actual), 1e-4)

    def test_compute_linear_truncated_kernel_with_batch(self):
        x1 = torch.tensor(
            [1, 0.1, 0.2, 3, 0.3, 0.4, 5, 0.5, 0.6, 7, 0.7, 0.8], dtype=torch.float
        ).view(2, 2, 3)
        x2 = torch.tensor(
            [2, 0.8, 0.7, 4, 0.6, 0.5, 6, 0.4, 0.3, 8, 0.2, 0.1], dtype=torch.float
        ).view(2, 2, 3)
        t_1 = torch.tensor(
            [0.2736, 0.44, 0.2304, 0.36, 0.3304, 0.3816, 0.1736, 0.1944],
            dtype=torch.float,
        ).view(2, 2, 2)
        batch_shape = torch.Size([2])
        for nu in {0.5, 1.5, 2.5}:
            for train_data_fidelity in {False, True}:
                kernel = LinearTruncatedFidelityKernel(nu=nu, batch_shape=batch_shape)
                kernel.power = 1
                kernel.train_data_fidelity = train_data_fidelity
                if train_data_fidelity:
                    active_dimsM = [0]
                    t_2 = torch.tensor(
                        [0.0527, 0.167, 0.0383, 0.1159, 0.1159, 0.167, 0.0383, 0.0527],
                        dtype=torch.float,
                    ).view(2, 2, 2)
                    t_3 = torch.tensor(
                        [0.1944, 0.3816, 0.1736, 0.3304, 0.36, 0.44, 0.2304, 0.2736],
                        dtype=torch.float,
                    ).view(2, 2, 2)
                    t = 1 + t_1 + t_2 + t_3
                else:
                    active_dimsM = [0, 1]
                    t = 1 + t_1

                matern_ker = MaternKernel(
                    nu=nu, active_dims=active_dimsM, batch_shape=batch_shape
                )
                matern_term = matern_ker(x1, x2).evaluate()
                actual = t * matern_term
                res = kernel(x1, x2).evaluate()
                self.assertLess(torch.norm(res - actual), 1e-4)

    def test_initialize_lengthscale_prior(self):
        kernel = LinearTruncatedFidelityKernel()
        self.assertTrue(isinstance(kernel.lengthscale_prior, GammaPrior))
        kernel.lengthscale_prior = NormalPrior(1, 1)
        self.assertTrue(isinstance(kernel.lengthscale_prior, NormalPrior))
        self.assertTrue(isinstance(kernel.lengthscale_2_prior, GammaPrior))
        kernel.lengthscale_2_prior = NormalPrior(1, 1)
        self.assertTrue(isinstance(kernel.lengthscale_2_prior, NormalPrior))
        kernel2 = LinearTruncatedFidelityKernel(lengthscale_prior=NormalPrior(1, 1))
        self.assertTrue(isinstance(kernel2.lengthscale_prior, NormalPrior))
        kernel2 = LinearTruncatedFidelityKernel(lengthscale_2_prior=NormalPrior(1, 1))
        self.assertTrue(isinstance(kernel2.lengthscale_2_prior, NormalPrior))

    def test_initialize_power_prior(self):
        kernel = LinearTruncatedFidelityKernel(power_prior=NormalPrior(1, 1))
        self.assertTrue(isinstance(kernel.power_prior, NormalPrior))

    def test_initialize_power(self):
        kernel = LinearTruncatedFidelityKernel()
        kernel.initialize(power=1)
        actual_value = torch.tensor(1, dtype=torch.float).view_as(kernel.power)
        self.assertLess(torch.norm(kernel.power - actual_value), 1e-5)

    def test_initialize_power_batch(self):
        kernel = LinearTruncatedFidelityKernel(batch_shape=torch.Size([2]))
        power_init = torch.tensor([1, 2], dtype=torch.float)
        kernel.initialize(power=power_init)
        actual_value = power_init.view_as(kernel.power)
        self.assertLess(torch.norm(kernel.power - actual_value), 1e-5)

    def test_raise_fidelity_error(self):
        # At least one of the two fidelity flags must be set.
        with self.assertRaises(UnsupportedError):
            LinearTruncatedFidelityKernel(
                train_iteration_fidelity=False, train_data_fidelity=False
            )

    def test_raise_matern_error(self):
        with self.assertRaises(ValueError):
            LinearTruncatedFidelityKernel(nu=1)

    def test_initialize_lengthscale_2(self):
        kernel = LinearTruncatedFidelityKernel()
        kernel.initialize(lengthscale_2=1)
        actual_value = torch.tensor(1, dtype=torch.float).view_as(kernel.lengthscale_2)
        self.assertLess(torch.norm(kernel.lengthscale_2 - actual_value), 1e-5)

    def test_initialize_lengthscale_2_batch(self):
        kernel = LinearTruncatedFidelityKernel(batch_shape=torch.Size([2]))
        lengthscale_2_init = torch.tensor([1, 2], dtype=torch.float)
        kernel.initialize(lengthscale_2=lengthscale_2_init)
        actual_value = lengthscale_2_init.view_as(kernel.lengthscale_2)
        self.assertLess(torch.norm(kernel.lengthscale_2 - actual_value), 1e-5)
