
Commit 4079164

James Wilson authored and facebook-github-bot committed
Fix gen_candidates_torch
Summary: Fix a bug where gradients are always zero after the first iteration of `gen_candidates_torch`. Also removed `verbose` in favor of a generic callback and did some minor cleanup.

Reviewed By: Balandat

Differential Revision: D34436282

fbshipit-source-id: 7be733b8db120771d63a5b91ca674813bfc5a832
1 parent 8cd792b commit 4079164
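
The zero-gradient failure came from the old loop rebinding `clamped_candidates` to a fresh leaf tensor on every iteration while the optimizer kept a reference to the original tensor: `zero_grad` cleared the registered tensor's gradient, the new gradient was assigned to the rebound tensor instead, and `step` never moved the tensor the next loss was computed from. Below is a minimal standalone sketch of that pattern in plain PyTorch (the quadratic objective and names are illustrative, not BoTorch code):

```python
import torch

x = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adam([x], lr=0.1)  # the optimizer holds a reference to *this* leaf

for i in range(3):
    loss = (x - 1.0).pow(2).sum()
    opt.zero_grad()                           # clears the grad of the registered tensor
    x.grad = torch.autograd.grad(loss, x)[0]  # after iter 1, lands on the wrong tensor
    opt.step()                                # registered tensor sees a zero/empty grad
    with torch.no_grad():
        # Rebinding `x` creates a fresh leaf; `opt` still updates the old one,
        # so the tensor the next loss is computed from never moves.
        x = x.clamp(-1.0, 1.0).requires_grad_(True)
    print(f"iter {i}: loss={loss.item():.3f}")  # stalls after the first step
```

The fix below avoids the rebinding: a single leaf stays registered with the optimizer, each iteration clamps into a temporary `X`, and the gradient taken at `X` is assigned back to the registered tensor before stepping.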

File tree: 1 file changed, +31 −35 lines


botorch/generation/gen.py

Lines changed: 31 additions & 35 deletions
@@ -11,7 +11,8 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from functools import partial
+from typing import Any, Callable, Dict, List, NoReturn, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -242,7 +243,7 @@ def gen_candidates_torch(
     upper_bounds: Optional[Union[float, Tensor]] = None,
     optimizer: Type[Optimizer] = torch.optim.Adam,
     options: Optional[Dict[str, Union[float, str]]] = None,
-    verbose: bool = True,
+    callback: Optional[Callable[[int, Tensor, Tensor], NoReturn]] = None,
     fixed_features: Optional[Dict[int, Optional[float]]] = None,
 ) -> Tuple[Tensor, Tensor]:
     r"""Generate a set of candidates using a `torch.optim` optimizer.
@@ -259,7 +260,9 @@ def gen_candidates_torch(
             candidate search.
         options: Options used to control the optimization. Includes
             maxiter: Maximum number of iterations
-        verbose: If True, provide verbose output.
+        callback: A callback function accepting the current iteration, loss,
+            and gradients as arguments. This function is executed after computing
+            the loss and gradients, but before calling the optimizer.
         fixed_features: This is a dictionary of feature indices to values, where
             all generated candidates will have features fixed to these values.
             If the dictionary value is None, then that feature will just be
@@ -289,7 +292,7 @@ def gen_candidates_torch(
 
     # if there are fixed features we may optimize over a domain of lower dimension
     if fixed_features:
-        _no_fixed_features = _remove_fixed_features_from_optimization(
+        subproblem = _remove_fixed_features_from_optimization(
             fixed_features=fixed_features,
             acquisition_function=acquisition_function,
             initial_conditions=initial_conditions,
@@ -301,55 +304,48 @@ def gen_candidates_torch(
 
         # call the routine with no fixed_features
         clamped_candidates, batch_acquisition = gen_candidates_torch(
-            initial_conditions=_no_fixed_features.initial_conditions,
-            acquisition_function=_no_fixed_features.acquisition_function,
-            lower_bounds=_no_fixed_features.lower_bounds,
-            upper_bounds=_no_fixed_features.upper_bounds,
+            initial_conditions=subproblem.initial_conditions,
+            acquisition_function=subproblem.acquisition_function,
+            lower_bounds=subproblem.lower_bounds,
+            upper_bounds=subproblem.upper_bounds,
             optimizer=optimizer,
             options=options,
-            verbose=verbose,
+            callback=callback,
             fixed_features=None,
         )
-        clamped_candidates = _no_fixed_features.acquisition_function._construct_X_full(
+        clamped_candidates = subproblem.acquisition_function._construct_X_full(
             clamped_candidates
         )
         return clamped_candidates, batch_acquisition
 
-    clamped_candidates = columnwise_clamp(
-        X=initial_conditions, lower=lower_bounds, upper=upper_bounds
-    ).requires_grad_(True)
-    bayes_optimizer = optimizer(
-        params=[clamped_candidates], lr=options.get("lr", 0.025)
-    )
+    _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
+    clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
+    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
+
     i = 0
     stop = False
     stopping_criterion = ExpMAStoppingCriterion(
         **_filter_kwargs(ExpMAStoppingCriterion, **options)
     )
     while not stop:
         i += 1
-        loss = -acquisition_function(clamped_candidates).sum()
-        if verbose:
-            print("Iter: {} - Value: {:.3f}".format(i, -(loss.item())))
-
-        def closure():
-            bayes_optimizer.zero_grad()
-            output_grad = torch.autograd.grad(loss, clamped_candidates)[0]
-            clamped_candidates.grad = output_grad
+        with torch.no_grad():
+            X = _clamp(clamped_candidates).requires_grad_(True)
+
+        loss = -acquisition_function(X).sum()
+        grad = torch.autograd.grad(loss, X)[0]
+        if callback:
+            callback(i, loss, grad)
+
+        def assign_grad():
+            _optimizer.zero_grad()
+            clamped_candidates.grad = grad
             return loss
 
-        bayes_optimizer.step(closure)
-        with torch.no_grad():
-            clamped_candidates = columnwise_clamp(
-                X=clamped_candidates, lower=lower_bounds, upper=upper_bounds
-            ).requires_grad_(True)
+        _optimizer.step(assign_grad)
         stop = stopping_criterion.evaluate(fvals=loss.detach())
-    clamped_candidates = columnwise_clamp(
-        X=clamped_candidates,
-        lower=lower_bounds,
-        upper=upper_bounds,
-        raise_on_violation=True,
-    )
+
+    clamped_candidates = _clamp(clamped_candidates)
     with torch.no_grad():
         batch_acquisition = acquisition_function(clamped_candidates)
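
Downstream code that relied on `verbose` can reproduce the old per-iteration printout through the new `callback` argument, which receives the iteration count, the loss (the negated acquisition value), and its gradient. A usage sketch; `X_init` and `acq_func` are placeholders for real starting points and a real acquisition function:

```python
from botorch.generation.gen import gen_candidates_torch

def log_progress(i, loss, grad):
    # Mirrors the removed `verbose` output, plus the gradient norm.
    print("Iter: {} - Value: {:.3f} - Grad norm: {:.3e}".format(
        i, -loss.item(), grad.norm().item()
    ))

candidates, acq_value = gen_candidates_torch(
    initial_conditions=X_init,      # placeholder: tensor of starting points
    acquisition_function=acq_func,  # placeholder: a BoTorch acquisition function
    lower_bounds=0.0,
    upper_bounds=1.0,
    callback=log_progress,          # called after loss/grad are computed, before each step
)
```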
