
Commit 98503e4

roussel-ryan authored and facebook-github-bot committed
Bugfix for Proximal acquisition function wrapper for negative base acquisition functions (#1447)
Summary:

## Motivation

This PR fixes a major issue when using the ```ProximalAcquisitionFunction``` wrapper with base acquisition functions that are not strictly positive. The fix applies a Softplus transform to the base acquisition function values (using an optional beta = 1.0 argument) before multiplying by the proximal weighting.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1447

Test Plan: Tests have been updated with the correct (softplus-transformed) values.

Reviewed By: Balandat

Differential Revision: D40238091

Pulled By: saitcakmak

fbshipit-source-id: 7529114c77bd9a3634d2ccc1aeeb333452122b80
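For context, a minimal usage sketch of the fixed wrapper (the data and names are illustrative; `UpperConfidenceBound` stands in for any base acquisition function that can return negative values). Without `beta`, the same call would raise the new `RuntimeError` whenever the base acquisition values dip below zero.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition import UpperConfidenceBound
from botorch.acquisition.proximal import ProximalAcquisitionFunction

# Illustrative training data for a 3-d problem; any single-batch model works here.
train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = torch.randn(10, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# UCB can be negative when the posterior mean is negative, so the proximal
# wrapper needs the SoftPlus transform enabled via its `beta` argument.
ucb = UpperConfidenceBound(model, beta=0.1)
proximal_ucb = ProximalAcquisitionFunction(
    ucb,
    proximal_weights=torch.ones(3, dtype=torch.double),
    beta=1.0,  # SoftPlus applied to base values before proximal weighting
)

test_X = torch.rand(5, 1, 3, dtype=torch.double)
values = proximal_ucb(test_X)  # shape torch.Size([5]); all values non-negative
```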
1 parent ab50b85 commit 98503e4

File tree

2 files changed: +78 −10 lines changed


botorch/acquisition/proximal.py

Lines changed: 29 additions & 6 deletions
@@ -27,15 +27,19 @@

 class ProximalAcquisitionFunction(AcquisitionFunction):
     """A wrapper around AcquisitionFunctions to add proximal weighting of the
-    acquisition function. Acquisition function is weighted via a squared exponential
-    centered at the last training point, with varying lengthscales corresponding to
-    `proximal_weights`. Can only be used with acquisition functions based on single
-    batch models.
+    acquisition function. The acquisition function is
+    weighted via a squared exponential centered at the last training point,
+    with varying lengthscales corresponding to `proximal_weights`. Can only be used
+    with acquisition functions based on single batch models. Acquisition functions
+    must be positive or `beta` must be specified to apply a SoftPlus transform before
+    proximal weighting.

     Small values of `proximal_weights` corresponds to strong biasing towards recently
     observed points, which smoothes optimization with a small potential decrese in
     convergence rate.

+
+
     Example:
         >>> model = SingleTaskGP(train_X, train_Y)
         >>> EI = ExpectedImprovement(model, best_f=0.0)
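As a standalone sketch of the squared-exponential weighting described in this docstring (illustrative tensors; the computation mirrors the `forward` hunk further below):

```python
import torch

# Illustrative candidate point, last training point, and per-dimension lengthscales.
x = torch.tensor([0.2, 0.4, 0.6])
last_X = torch.tensor([0.1, 0.5, 0.5])
proximal_weights = torch.ones(3)

# Squared exponential centered at the last training point; smaller entries in
# `proximal_weights` penalize distance along that dimension more strongly.
M = torch.linalg.norm((x - last_X) / proximal_weights, dim=-1) ** 2
weight = torch.exp(-0.5 * M)  # in (0, 1]; equals 1 exactly at x == last_X
```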
@@ -48,7 +52,8 @@ def __init__(
         self,
         acq_function: AcquisitionFunction,
         proximal_weights: Tensor,
-        transformed_weighting: bool = True,
+        transformed_weighting: Optional[bool] = True,
+        beta: Optional[float] = None,
     ) -> None:
         r"""Derived Acquisition Function weighted by proximity to recently
         observed point.
@@ -62,6 +67,8 @@ def __init__(
                 the transformed input space given by
                 `acq_function.model.input_transform` (if available), otherwise
                 proximal weights are applied in real input space.
+            beta: If not None, apply a softplus transform to the base acquisition
+                function, allows negative base acquisition function values.
         """
         Module.__init__(self)

@@ -79,6 +86,9 @@ def __init__(
         self.register_buffer(
             "transformed_weighting", torch.tensor(transformed_weighting)
         )
+
+        self.register_buffer("beta", None if beta is None else torch.tensor(beta))
+
         _validate_model(model, proximal_weights)

     @t_batch_mode_transform(expected_q=1, assert_output_shape=False)
@@ -127,7 +137,20 @@ def forward(self, X: Tensor) -> Tensor:

         M = torch.linalg.norm(diff / self.proximal_weights, dim=-1) ** 2
         proximal_acq_weight = torch.exp(-0.5 * M)
-        return self.acq_func(X) * proximal_acq_weight.flatten()
+
+        base_acqf = self.acq_func(X)
+        if self.beta is None:
+            if torch.any(base_acqf < 0):
+                raise RuntimeError(
+                    "Cannot use proximal biasing for negative "
+                    "acquisition function values, set a value for beta to "
+                    "fix this with a softplus transform"
+                )
+
+        else:
+            base_acqf = torch.nn.functional.softplus(base_acqf, beta=self.beta)
+
+        return base_acqf * proximal_acq_weight.flatten()


 def _validate_model(model: Model, proximal_weights: Tensor) -> None:
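Why the SoftPlus transform is needed, as a small numeric sketch (values are illustrative): with a negative base value, multiplying by a proximal weight in (0, 1] makes the value larger, so a worse point far from the last observation can outscore a better nearby one; mapping through SoftPlus first keeps all values positive so the weighting behaves as intended.

```python
import torch

# Two candidates: the first is near the last observation (weight 1.0) and has the
# better base acquisition value; the second is far away (weight 0.1) and is worse.
acq_values = torch.tensor([-0.5, -2.0])
prox_weights = torch.tensor([1.0, 0.1])

# Without the transform the far, worse candidate wins (-0.2 > -0.5),
# inverting the intended proximal bias.
print(acq_values * prox_weights)  # tensor([-0.5000, -0.2000])

# After SoftPlus (beta=1.0) all values are positive and the near,
# better candidate wins again (~0.47 > ~0.01).
soft = torch.nn.functional.softplus(acq_values, beta=1.0)
print(soft * prox_weights)  # tensor([0.4741, 0.0127])
```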

test/acquisition/test_proximal.py

Lines changed: 49 additions & 4 deletions
@@ -36,6 +36,11 @@ def forward(self, X):
         pass


+class NegativeAcquisitionFunction(AcquisitionFunction):
+    def forward(self, X):
+        return torch.ones(*X.shape[:-1]) * -1.0
+
+
 class TestProximalAcquisitionFunction(BotorchTestCase):
     def test_proximal(self):
         for dtype in (torch.float, torch.double):
@@ -68,6 +73,7 @@ def test_proximal(self):
                        transformed_weighting=transformed_weighting,
                    )

+                   # softplus transformed value of the acquisition function
                    ei = EI(test_X)

                    # modify last_X/test_X depending on transformed_weighting
@@ -84,7 +90,34 @@ def test_proximal(self):

                    ei_prox = EI_prox(test_X)
                    self.assertTrue(torch.allclose(ei_prox, ei * test_prox_weight))
-                   self.assertTrue(ei_prox.shape == torch.Size([1]))
+                   self.assertEqual(ei_prox.shape, torch.Size([1]))
+
+                   # test with beta specified
+                   EI_prox_beta = ProximalAcquisitionFunction(
+                       EI,
+                       proximal_weights=proximal_weights,
+                       transformed_weighting=transformed_weighting,
+                       beta=1.0,
+                   )
+
+                   # SoftPlus transformed value of the acquisition function
+                   ei = torch.nn.functional.softplus(EI(test_X), beta=1.0)
+
+                   # modify last_X/test_X depending on transformed_weighting
+                   proximal_test_X = test_X.clone()
+                   if transformed_weighting:
+                       if input_transform is not None:
+                           last_X = input_transform(train_X[-1])
+                           proximal_test_X = input_transform(test_X)
+
+                   mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
+                   test_prox_weight = torch.exp(
+                       mv_normal.log_prob(proximal_test_X) - mv_normal.log_prob(last_X)
+                   )
+
+                   ei_prox_beta = EI_prox_beta(test_X)
+                   self.assertTrue(torch.allclose(ei_prox_beta, ei * test_prox_weight))
+                   self.assertEqual(ei_prox_beta.shape, torch.Size([1]))

                    # test t-batch with broadcasting
                    test_X = torch.rand(4, 1, 3, device=self.device, dtype=dtype)
@@ -104,7 +137,7 @@ def test_proximal(self):
                    self.assertTrue(
                        torch.allclose(ei_prox, ei * test_prox_weight.flatten())
                    )
-                   self.assertTrue(ei_prox.shape == torch.Size([4]))
+                   self.assertEqual(ei_prox.shape, torch.Size([4]))

                    # test q-based MC acquisition function
                    qEI = qExpectedImprovement(model, best_f=0.0)
@@ -133,6 +166,18 @@ def test_proximal(self):
                    )
                    self.assertEqual(qei_prox.shape, torch.Size([4]))

+                   # test acquisition function with
+                   # negative values w/o SoftPlus transform specified
+                   negative_acqf = NegativeAcquisitionFunction(model)
+                   bad_neg_prox = ProximalAcquisitionFunction(
+                       negative_acqf, proximal_weights=proximal_weights
+                   )
+
+                   with self.assertRaisesRegex(
+                       RuntimeError, "Cannot use proximal biasing for negative"
+                   ):
+                       bad_neg_prox(test_X)
+
                    # test gradient
                    test_X = torch.rand(
                        1, 3, device=self.device, dtype=dtype, requires_grad=True
@@ -228,7 +273,7 @@ def test_proximal_model_list(self):
            ei_prox = EI_prox(test_X)

            self.assertTrue(torch.allclose(ei_prox, ei * test_prox_weight))
-           self.assertTrue(ei_prox.shape == torch.Size([1]))
+           self.assertEqual(ei_prox.shape, torch.Size([1]))

            # test MC acquisition function
            qEI = qExpectedImprovement(model, best_f=0.0, objective=mc_linear_objective)
@@ -245,7 +290,7 @@ def test_proximal_model_list(self):

            qei_prox = qEI_prox(test_X)
            self.assertTrue(torch.allclose(qei_prox, qei * test_prox_weight.flatten()))
-           self.assertTrue(qei_prox.shape == torch.Size([4]))
+           self.assertEqual(qei_prox.shape, torch.Size([4]))

            # test gradient
            test_X = torch.rand(
