Skip to content

Commit 5fbe736

Browse files
wjmaddoxfacebook-github-bot
authored andcommitted
Backprop wrt candidates in gen_candidates_torch (#766)
Summary: ## Motivation Bugfix for #765. The current error was occurring when trying to optimize qKnowledgeGradient with gen_candidates_torch. The fix is to backprop only wrt candidates like what occurs in `gen_candidates_scipy`. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/master/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #766 Test Plan: Not sure how to setup a unit test here. ## Related PRs N/A. Reviewed By: Balandat Differential Revision: D29401945 Pulled By: wjmaddox fbshipit-source-id: a00e0b93b41206487b8b5ecb39bd19116d533b58
1 parent 271d4c0 commit 5fbe736

File tree

2 files changed

+76
-40
lines changed

2 files changed

+76
-40
lines changed

botorch/generation/gen.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,6 @@ def gen_candidates_torch(
276276
bayes_optimizer = optimizer(
277277
params=[clamped_candidates], lr=options.get("lr", 0.025)
278278
)
279-
param_trajectory: Dict[str, List[Tensor]] = {"candidates": []}
280-
loss_trajectory: List[float] = []
281279
i = 0
282280
stop = False
283281
stopping_criterion = ExpMAStoppingCriterion(
@@ -288,19 +286,18 @@ def gen_candidates_torch(
288286
loss = -acquisition_function(clamped_candidates).sum()
289287
if verbose:
290288
print("Iter: {} - Value: {:.3f}".format(i, -(loss.item())))
291-
loss_trajectory.append(loss.item())
292-
param_trajectory["candidates"].append(clamped_candidates.clone())
293289

294290
def closure():
295291
bayes_optimizer.zero_grad()
296-
loss.backward()
292+
output_grad = torch.autograd.grad(loss, clamped_candidates)[0]
293+
clamped_candidates.grad = output_grad
297294
return loss
298295

299296
bayes_optimizer.step(closure)
300297
with torch.no_grad():
301298
clamped_candidates = columnwise_clamp(
302299
X=clamped_candidates, lower=lower_bounds, upper=upper_bounds
303-
)
300+
).requires_grad_(True)
304301
stop = stopping_criterion.evaluate(fvals=loss.detach())
305302
clamped_candidates = columnwise_clamp(
306303
X=clamped_candidates,

test/generation/test_gen.py

Lines changed: 73 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from unittest import mock
1010

1111
import torch
12-
from botorch.acquisition import qExpectedImprovement
12+
from botorch.acquisition import qExpectedImprovement, qKnowledgeGradient
1313
from botorch.exceptions.warnings import OptimizationWarning
1414
from botorch.fit import fit_gpytorch_model
1515
from botorch.generation.gen import (
@@ -72,15 +72,28 @@ def test_gen_candidates(self, gen_candidates=gen_candidates_scipy, options=None)
7272
options = {**options, "maxiter": 5}
7373
for double in (True, False):
7474
self._setUp(double=double)
75-
qEI = qExpectedImprovement(self.model, best_f=self.f_best)
76-
candidates, _ = gen_candidates(
77-
initial_conditions=self.initial_conditions,
78-
acquisition_function=qEI,
79-
lower_bounds=0,
80-
upper_bounds=1,
81-
options=options or {},
82-
)
83-
self.assertTrue(-EPS <= candidates <= 1 + EPS)
75+
acqfs = [
76+
qExpectedImprovement(self.model, best_f=self.f_best),
77+
qKnowledgeGradient(
78+
self.model, num_fantasies=4, current_value=self.f_best
79+
),
80+
]
81+
for acqf in acqfs:
82+
ics = self.initial_conditions
83+
if isinstance(acqf, qKnowledgeGradient):
84+
ics = ics.repeat(5, 1)
85+
86+
candidates, _ = gen_candidates(
87+
initial_conditions=ics,
88+
acquisition_function=acqf,
89+
lower_bounds=0,
90+
upper_bounds=1,
91+
options=options or {},
92+
)
93+
if isinstance(acqf, qKnowledgeGradient):
94+
candidates = acqf.extract_candidates(candidates)
95+
96+
self.assertTrue(-EPS <= candidates <= 1 + EPS)
8497

8598
def test_gen_candidates_torch(self):
8699
self.test_gen_candidates(
@@ -96,18 +109,30 @@ def test_gen_candidates_with_none_fixed_features(
96109
options = {**options, "maxiter": 5}
97110
for double in (True, False):
98111
self._setUp(double=double, expand=True)
99-
qEI = qExpectedImprovement(self.model, best_f=self.f_best)
100-
candidates, _ = gen_candidates(
101-
initial_conditions=self.initial_conditions,
102-
acquisition_function=qEI,
103-
lower_bounds=0,
104-
upper_bounds=1,
105-
fixed_features={1: None},
106-
options=options or {},
107-
)
108-
candidates = candidates.squeeze(0)
109-
self.assertTrue(-EPS <= candidates[0] <= 1 + EPS)
110-
self.assertTrue(candidates[1].item() == 1.0)
112+
acqfs = [
113+
qExpectedImprovement(self.model, best_f=self.f_best),
114+
qKnowledgeGradient(
115+
self.model, num_fantasies=4, current_value=self.f_best
116+
),
117+
]
118+
for acqf in acqfs:
119+
ics = self.initial_conditions
120+
if isinstance(acqf, qKnowledgeGradient):
121+
ics = ics.repeat(5, 1)
122+
123+
candidates, _ = gen_candidates(
124+
initial_conditions=ics,
125+
acquisition_function=acqf,
126+
lower_bounds=0,
127+
upper_bounds=1,
128+
fixed_features={1: None},
129+
options=options or {},
130+
)
131+
if isinstance(acqf, qKnowledgeGradient):
132+
candidates = acqf.extract_candidates(candidates)
133+
candidates = candidates.squeeze(0)
134+
self.assertTrue(-EPS <= candidates[0] <= 1 + EPS)
135+
self.assertTrue(candidates[1].item() == 1.0)
111136

112137
def test_gen_candidates_torch_with_none_fixed_features(self):
113138
self.test_gen_candidates_with_none_fixed_features(
@@ -121,18 +146,32 @@ def test_gen_candidates_with_fixed_features(
121146
options = {**options, "maxiter": 5}
122147
for double in (True, False):
123148
self._setUp(double=double, expand=True)
124-
qEI = qExpectedImprovement(self.model, best_f=self.f_best)
125-
candidates, _ = gen_candidates(
126-
initial_conditions=self.initial_conditions,
127-
acquisition_function=qEI,
128-
lower_bounds=0,
129-
upper_bounds=1,
130-
fixed_features={1: 0.25},
131-
options=options,
132-
)
133-
candidates = candidates.squeeze(0)
134-
self.assertTrue(-EPS <= candidates[0] <= 1 + EPS)
135-
self.assertTrue(candidates[1].item() == 0.25)
149+
acqfs = [
150+
qExpectedImprovement(self.model, best_f=self.f_best),
151+
qKnowledgeGradient(
152+
self.model, num_fantasies=4, current_value=self.f_best
153+
),
154+
]
155+
for acqf in acqfs:
156+
ics = self.initial_conditions
157+
if isinstance(acqf, qKnowledgeGradient):
158+
ics = ics.repeat(5, 1)
159+
160+
candidates, _ = gen_candidates(
161+
initial_conditions=ics,
162+
acquisition_function=acqf,
163+
lower_bounds=0,
164+
upper_bounds=1,
165+
fixed_features={1: 0.25},
166+
options=options,
167+
)
168+
169+
if isinstance(acqf, qKnowledgeGradient):
170+
candidates = acqf.extract_candidates(candidates)
171+
172+
candidates = candidates.squeeze(0)
173+
self.assertTrue(-EPS <= candidates[0] <= 1 + EPS)
174+
self.assertTrue(candidates[1].item() == 0.25)
136175

137176
def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
138177
options = {"maxiter": 5}

0 commit comments

Comments
 (0)