Skip to content

Commit c5ec613

Browse files
Abbas Kazerouni and facebook-github-bot
authored and committed
Moving penalized acqfn from botorch_fb to botorch (#585)
Summary: Pull Request resolved: #585 Just moved the development for penalized acqfn from botorch_fb to botorch to push it to the OSS. Reviewed By: Balandat Differential Revision: D24508442 fbshipit-source-id: 54f0884e8e5a86296c6d0e58cc913bbf46323dbb
1 parent 4033c16 commit c5ec613

File tree

3 files changed

+308
-0
lines changed

3 files changed

+308
-0
lines changed

botorch/acquisition/penalized.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
Modules to add regularization to acquisition functions.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import math
14+
from typing import List, Optional
15+
16+
import torch
17+
from botorch.acquisition.acquisition import AcquisitionFunction
18+
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
19+
from botorch.exceptions import UnsupportedError
20+
from torch import Tensor
21+
22+
23+
class L2Penalty(torch.nn.Module):
    r"""Squared-L2 distance penalty, usable with any acquisition function."""

    def __init__(self, init_point: Tensor):
        r"""Store the reference point of the penalty.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty on a set of candidates.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A "batch_shape" tensor holding, for each q-batch, the largest
            Euclidean distance of its points from `init_point`, squared.
        """
        offsets = X - self.init_point
        # Distance of every point in the q-batch from the reference point.
        distances = torch.norm(offsets, p=2, dim=-1)
        # Only the farthest point of each q-batch contributes to the penalty.
        farthest = distances.max(dim=-1).values
        return farthest ** 2
48+
49+
50+
class GaussianPenalty(torch.nn.Module):
    r"""Gaussian-shaped penalty, usable with any acquisition function."""

    def __init__(self, init_point: Tensor, sigma: float):
        r"""Store the reference point and the width of the penalty.

        Args:
            init_point: The "1 x dim" reference point against which
                we want to regularize.
            sigma: The parameter used in gaussian function.
        """
        super().__init__()
        self.init_point = init_point
        self.sigma = sigma

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty on a set of candidates.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A "batch_shape" tensor holding, for each q-batch, the largest value
            of `exp(||x - init_point||^2 / (2 * sigma^2))` over its points.
        """
        distances = torch.norm(X - self.init_point, p=2, dim=-1)
        # The exponent is positive, so the penalty grows with the distance
        # from the reference point.
        exponent = distances ** 2 / 2 / self.sigma ** 2
        per_point_penalty = torch.exp(exponent)
        return per_point_penalty.max(dim=-1).values
77+
78+
79+
class GroupLassoPenalty(torch.nn.Module):
    r"""Group-lasso penalty, usable with any acquisition function."""

    def __init__(self, init_point: Tensor, groups: List[List[int]]):
        r"""Store the reference point and the index groups of the penalty.

        Args:
            init_point: The "1 x dim" reference point against which we want
                to regularize.
            groups: Groups of indices used in group lasso.
        """
        super().__init__()
        self.init_point = init_point
        self.groups = groups

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the group-lasso penalty of `X - init_point`.

        Args:
            X: A "batch_shape x 1 x dim" tensor. Evaluation for q-batches
                with q > 1 is not implemented yet.

        Returns:
            The value of `group_lasso_regularizer` applied to the offset of
            `X` (with the q-dimension squeezed out) from `init_point`.

        Raises:
            NotImplementedError: If the q-dimension of `X` is not 1.
        """
        q = X.shape[-2]
        if q != 1:
            raise NotImplementedError(
                "group-lasso has not been implemented for q>1 yet."
            )

        # Drop the singleton q-dimension before measuring the offset.
        offset = X.squeeze(-2) - self.init_point
        return group_lasso_regularizer(X=offset, groups=self.groups)
108+
109+
110+
class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.

    The wrapped value is `raw_acqf(X) - regularization_parameter * penalty(X)`.

    The usage is similar to:
        raw_acqf = NoisyExpectedImprovement(...)
        penalty = GroupLassoPenalty(...)
        lmbda = 0.1  # the regularization parameter
        acqf = PenalizedAcquisitionFunction(raw_acqf, penalty, lmbda)
    """

    def __init__(
        self,
        raw_acqf: AcquisitionFunction,
        penalty_func: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Initializing the penalized acquisition function.

        Args:
            raw_acqf: The raw acquisition function that is going to be regularized.
            penalty_func: The regularization function.
            regularization_parameter: Regularization parameter used in optimization.
        """
        super().__init__(model=raw_acqf.model)
        self.raw_acqf = raw_acqf
        self.penalty_func = penalty_func
        self.regularization_parameter = regularization_parameter

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the raw acquisition function minus the weighted penalty.

        Args:
            X: The candidate points, passed unchanged to both the raw
                acquisition function and the penalty function.

        Returns:
            `raw_acqf(X) - regularization_parameter * penalty_func(X)`.
        """
        raw_value = self.raw_acqf(X=X)
        penalty_term = self.penalty_func(X)
        return raw_value - self.regularization_parameter * penalty_term

    @property
    def X_pending(self) -> Optional[Tensor]:
        r"""Pending points of the wrapped acquisition function, if any."""
        # The wrapped acquisition function may never have defined `X_pending`
        # (e.g. when pending points were never set on it); report that as
        # "no pending points" rather than raising an AttributeError.
        return getattr(self.raw_acqf, "X_pending", None)

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Forward the pending points to the wrapped acquisition function.

        Args:
            X_pending: The pending points, or None to clear them.

        Raises:
            UnsupportedError: If the wrapped acquisition function is an
                `AnalyticAcquisitionFunction`.
        """
        if isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
            raise UnsupportedError(
                "The raw acquisition function is Analytic and does not account "
                "for X_pending yet."
            )
        self.raw_acqf.set_X_pending(X_pending=X_pending)
154+
155+
156+
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
    r"""Computes the group lasso regularization function for the given point.

    For each group `g` the term is `sqrt(len(g)) * ||X[..., g]||_2`; the terms
    are summed over all groups.

    Args:
        X: A bxd tensor representing the points to evaluate the regularization at.
        groups: List of indices of different groups.

    Returns:
        Computed group lasso norm at the given points, with the last dimension
        of `X` reduced away. If `groups` is empty the result is all zeros.
    """
    if not groups:
        # An empty sum is zero; `torch.stack` would raise on an empty list.
        return X.new_zeros(X.shape[:-1])
    group_terms = [
        math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups
    ]
    return torch.sum(torch.stack(group_terms, dim=-1), dim=-1)

sphinx/source/acquisition.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ Fixed Feature Acquisition Function
110110
.. automodule:: botorch.acquisition.fixed_feature
111111
:members:
112112

113+
Penalized Acquisition Function Wrapper
114+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
115+
.. automodule:: botorch.acquisition.penalized
116+
:members:
117+
113118
General Utilities for Acquisition Functions
114119
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
115120
.. automodule:: botorch.acquisition.utils

test/acquisition/test_penalized.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from botorch.acquisition.analytic import ExpectedImprovement
9+
from botorch.acquisition.monte_carlo import qExpectedImprovement
10+
from botorch.acquisition.penalized import (
11+
GaussianPenalty,
12+
GroupLassoPenalty,
13+
L2Penalty,
14+
PenalizedAcquisitionFunction,
15+
group_lasso_regularizer,
16+
)
17+
from botorch.exceptions import UnsupportedError
18+
from botorch.sampling.samplers import IIDNormalSampler
19+
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
20+
21+
22+
class TestL2Penalty(BotorchTestCase):
    r"""Tests for `L2Penalty`."""

    def test_l2_penalty(self):
        r"""The penalty equals the largest squared distance within the q-batch."""
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            l2_module = L2Penalty(init_point=init_point)

            # testing a single q-batch made of two points (q=2, dim=3)
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            real_value = diff_norm_squared.max(dim=-1).values
            computed_value = l2_module(sample_point)
            self.assertEqual(computed_value.item(), real_value.item())
39+
40+
41+
class TestGaussianPenalty(BotorchTestCase):
    r"""Tests for `GaussianPenalty`."""

    def test_gaussian_penalty(self):
        r"""The penalty equals the exponential of the largest scaled distance."""
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            # sigma is chosen so that the expected value, exp(14 / 2) = exp(7),
            # stays finite in float32. A much smaller sigma (e.g. 0.1) gives
            # exp(700), which overflows to inf in float32 and would reduce the
            # assertion below to `inf == inf`.
            sigma = 1.0
            gaussian_module = GaussianPenalty(init_point=init_point, sigma=sigma)

            # testing a single q-batch made of two points (q=2, dim=3)
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            max_l2_distance = diff_norm_squared.max(dim=-1).values
            real_value = torch.exp(max_l2_distance / 2 / sigma ** 2)
            computed_value = gaussian_module(sample_point)
            # Guard against the comparison silently degenerating to inf == inf.
            self.assertTrue(torch.isfinite(real_value).item())
            self.assertEqual(computed_value.item(), real_value.item())
60+
61+
62+
class TestGroupLassoPenalty(BotorchTestCase):
    r"""Tests for `GroupLassoPenalty`."""

    def test_group_lasso_penalty(self):
        r"""Check the q=1 value and the rejection of q>1 inputs."""
        groups = [[0, 2], [1]]
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            init_point = torch.tensor([0.5, 0.5, 0.5], **tkwargs)
            group_lasso_module = GroupLassoPenalty(init_point=init_point, groups=groups)

            # A single point (q=1): the module must agree with the functional
            # regularizer applied to the offset from the reference point.
            single = torch.tensor([[1.0, 2.0, 3.0]], **tkwargs)
            expected = group_lasso_regularizer(single - init_point, groups)
            actual = group_lasso_module(single)
            self.assertEqual(actual.item(), expected.item())

            # Two points (X.shape[-2] > 1) are not supported.
            pair = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs)
            with self.assertRaises(NotImplementedError):
                group_lasso_module(pair)
85+
86+
87+
class TestPenalizedAcquisitionFunction(BotorchTestCase):
    r"""Tests for `PenalizedAcquisitionFunction`."""

    def test_penalized_acquisition_function(self):
        r"""Check the penalized value and the handling of `X_pending`."""
        for dtype in (torch.float, torch.double):
            # Create the mock posterior on the device/dtype under test so it is
            # consistent with every other tensor used below.
            mock_model = MockModel(
                MockPosterior(
                    mean=torch.tensor([1.0], device=self.device, dtype=dtype),
                    variance=torch.tensor([1.0], device=self.device, dtype=dtype),
                )
            )
            init_point = torch.tensor([0.5, 0.5, 0.5], device=self.device, dtype=dtype)
            groups = [[0, 2], [1]]
            raw_acqf = ExpectedImprovement(model=mock_model, best_f=1.0)
            penalty = GroupLassoPenalty(init_point=init_point, groups=groups)
            lmbda = 0.1
            acqf = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf, penalty_func=penalty, regularization_parameter=lmbda
            )

            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0]], device=self.device, dtype=dtype
            )
            raw_value = raw_acqf(sample_point)
            penalty_value = penalty(sample_point)
            real_value = raw_value - lmbda * penalty_value
            computed_value = acqf(sample_point)
            self.assertTrue(torch.equal(real_value, computed_value))

            # testing X_pending for analytic raw_acqfn (EI)
            X_pending = torch.tensor([0.1, 0.2, 0.3], device=self.device, dtype=dtype)
            with self.assertRaises(UnsupportedError):
                acqf.set_X_pending(X_pending)

            # testing X_pending for non-analytic raw_acqfn (qEI)
            sampler = IIDNormalSampler(num_samples=2)
            raw_acqf_2 = qExpectedImprovement(
                model=mock_model, best_f=0, sampler=sampler
            )
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            l2_module = L2Penalty(init_point=init_point)
            acqf_2 = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf_2,
                penalty_func=l2_module,
                regularization_parameter=lmbda,
            )

            X_pending = torch.tensor([0.1, 0.2, 0.3], device=self.device, dtype=dtype)
            acqf_2.set_X_pending(X_pending)
            self.assertTrue(torch.equal(acqf_2.X_pending, X_pending))

0 commit comments

Comments
 (0)