Commit 271d4c0

vishwakftw authored and facebook-github-bot committed
Constrain optimization over unfixed features if fixed_features is passed. (#839)
Summary:
Pull Request resolved: #839

This diff implements a change previously discussed in https://www.internalfb.com/diff/D28869203?dst_version_fbid=180392114012471&transaction_fbid=280143027144841. Constraining the optimization to the unfixed features is a sensible thing to do and will potentially save a lot of compute.

Reviewed By: Balandat

Differential Revision: D29265778

fbshipit-source-id: e8018a45e17394c865f414df44f9ceafaf61f0b8
1 parent f7a2fb7 · commit 271d4c0

File tree

6 files changed: +452 −22 lines
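
For orientation, here is a minimal sketch of the user-facing call that this commit speeds up; the GP model, data, and fixed value below are illustrative assumptions, not part of the commit. With fixed_features passed, gen_candidates_scipy now searches only the unfixed columns and re-inserts the fixed column into the returned candidates:

# Illustrative sketch: the model, data, and fixed value are assumptions.
import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
acqf = UpperConfidenceBound(SingleTaskGP(train_X, train_Y), beta=0.1)

# Feature 1 is fixed to 0.5, so scipy now optimizes over a 2-d (not 3-d)
# domain; the fixed column is re-inserted into the output afterwards.
candidates, acq_values = gen_candidates_scipy(
    initial_conditions=torch.rand(5, 1, 3, dtype=torch.double),
    acquisition_function=acqf,
    lower_bounds=torch.zeros(3, dtype=torch.double),
    upper_bounds=torch.ones(3, dtype=torch.double),
    fixed_features={1: 0.5},
)

Before this commit, the same call optimized over all three dimensions and relied on fix_features to overwrite the fixed column at every evaluation.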

botorch/generation/gen.py

Lines changed: 91 additions & 17 deletions
@@ -10,10 +10,12 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
+from botorch.acquisition import AcquisitionFunction
+from botorch.generation.utils import _remove_fixed_features_from_optimization
 from botorch.optim.parameter_constraints import (
     _arrayify,
     make_scipy_bounds,
@@ -23,13 +25,12 @@
 from botorch.optim.utils import _filter_kwargs, columnwise_clamp, fix_features
 from scipy.optimize import minimize
 from torch import Tensor
-from torch.nn import Module
 from torch.optim import Optimizer
 
 
 def gen_candidates_scipy(
     initial_conditions: Tensor,
-    acquisition_function: Module,
+    acquisition_function: AcquisitionFunction,
     lower_bounds: Optional[Union[float, Tensor]] = None,
     upper_bounds: Optional[Union[float, Tensor]] = None,
     inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
@@ -83,9 +84,45 @@ def gen_candidates_scipy(
     )
     """
     options = options or {}
+
+    # REDUCED is used to indicate if we are optimizing over a reduced domain
+    # dimension after considering fixed_features.
+    # We are in REDUCED mode if fixed_features is not None, except when
+    # fixed_features.values() contains None and linear constraints are passed.
+    REDUCED = fixed_features is not None
+    if inequality_constraints or equality_constraints:
+        REDUCED = REDUCED and (None not in fixed_features.values())
+
+    if REDUCED:
+        _no_fixed_features = _remove_fixed_features_from_optimization(
+            fixed_features=fixed_features,
+            acquisition_function=acquisition_function,
+            initial_conditions=initial_conditions,
+            lower_bounds=lower_bounds,
+            upper_bounds=upper_bounds,
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+        )
+
+        # call the routine with no fixed_features
+        clamped_candidates, batch_acquisition = gen_candidates_scipy(
+            initial_conditions=_no_fixed_features.initial_conditions,
+            acquisition_function=_no_fixed_features.acquisition_function,
+            lower_bounds=_no_fixed_features.lower_bounds,
+            upper_bounds=_no_fixed_features.upper_bounds,
+            inequality_constraints=_no_fixed_features.inequality_constraints,
+            equality_constraints=_no_fixed_features.equality_constraints,
+            options=options,
+            fixed_features=None,
+        )
+        clamped_candidates = _no_fixed_features.acquisition_function._construct_X_full(
+            clamped_candidates
+        )
+        return clamped_candidates, batch_acquisition
+
     clamped_candidates = columnwise_clamp(
         X=initial_conditions, lower=lower_bounds, upper=upper_bounds
-    ).requires_grad_(True)
+    )
 
     shapeX = clamped_candidates.shape
     x0 = _arrayify(clamped_candidates.view(-1))
@@ -111,7 +148,7 @@ def f(x):
             .contiguous()
             .requires_grad_(True)
         )
-        X_fix = fix_features(X=X, fixed_features=fixed_features)
+        X_fix = fix_features(X, fixed_features=fixed_features)
        loss = -acquisition_function(X_fix).sum()
        # compute gradient w.r.t. the inputs (does not accumulate in leaves)
        gradf = _arrayify(torch.autograd.grad(loss, X)[0].contiguous().view(-1))
@@ -137,20 +174,22 @@ def f(x):
         options={k: v for k, v in options.items() if k not in ["method", "callback"]},
     )
     candidates = fix_features(
-        X=torch.from_numpy(res.x).to(initial_conditions).view(shapeX).contiguous(),
+        X=torch.from_numpy(res.x).to(initial_conditions).reshape(shapeX),
         fixed_features=fixed_features,
     )
+
     clamped_candidates = columnwise_clamp(
         X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True
     )
     with torch.no_grad():
         batch_acquisition = acquisition_function(clamped_candidates)
+
     return clamped_candidates, batch_acquisition
 
 
 def gen_candidates_torch(
     initial_conditions: Tensor,
-    acquisition_function: Callable,
+    acquisition_function: AcquisitionFunction,
     lower_bounds: Optional[Union[float, Tensor]] = None,
     upper_bounds: Optional[Union[float, Tensor]] = None,
     optimizer: Type[Optimizer] = torch.optim.Adam,
@@ -199,10 +238,41 @@ def gen_candidates_torch(
     )
     """
     options = options or {}
+
+    # REDUCED is used to indicate if we are optimizing over a reduced domain
+    # dimension after considering fixed_features.
+    REDUCED = fixed_features is not None
+
+    if REDUCED:
+        _no_fixed_features = _remove_fixed_features_from_optimization(
+            fixed_features=fixed_features,
+            acquisition_function=acquisition_function,
+            initial_conditions=initial_conditions,
+            lower_bounds=lower_bounds,
+            upper_bounds=upper_bounds,
+            inequality_constraints=None,
+            equality_constraints=None,
+        )
+
+        # call the routine with no fixed_features
+        clamped_candidates, batch_acquisition = gen_candidates_torch(
+            initial_conditions=_no_fixed_features.initial_conditions,
+            acquisition_function=_no_fixed_features.acquisition_function,
+            lower_bounds=_no_fixed_features.lower_bounds,
+            upper_bounds=_no_fixed_features.upper_bounds,
+            optimizer=optimizer,
+            options=options,
+            verbose=verbose,
+            fixed_features=None,
+        )
+        clamped_candidates = _no_fixed_features.acquisition_function._construct_X_full(
+            clamped_candidates
+        )
+        return clamped_candidates, batch_acquisition
+
     clamped_candidates = columnwise_clamp(
         X=initial_conditions, lower=lower_bounds, upper=upper_bounds
     ).requires_grad_(True)
-    candidates = fix_features(clamped_candidates, fixed_features)
     bayes_optimizer = optimizer(
         params=[clamped_candidates], lr=options.get("lr", 0.025)
     )
@@ -215,29 +285,33 @@ def gen_candidates_torch(
     )
     while not stop:
         i += 1
-        loss = -acquisition_function(candidates).sum()
+        loss = -acquisition_function(clamped_candidates).sum()
         if verbose:
             print("Iter: {} - Value: {:.3f}".format(i, -(loss.item())))
         loss_trajectory.append(loss.item())
-        param_trajectory["candidates"].append(candidates.clone())
+        param_trajectory["candidates"].append(clamped_candidates.clone())
 
         def closure():
             bayes_optimizer.zero_grad()
             loss.backward()
             return loss
 
         bayes_optimizer.step(closure)
-        clamped_candidates.data = columnwise_clamp(
-            clamped_candidates, lower_bounds, upper_bounds
-        )
-        candidates = fix_features(clamped_candidates, fixed_features)
+        with torch.no_grad():
+            clamped_candidates = columnwise_clamp(
+                X=clamped_candidates, lower=lower_bounds, upper=upper_bounds
+            )
         stop = stopping_criterion.evaluate(fvals=loss.detach())
     clamped_candidates = columnwise_clamp(
-        X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True
+        X=clamped_candidates,
+        lower=lower_bounds,
+        upper=upper_bounds,
+        raise_on_violation=True,
     )
     with torch.no_grad():
-        batch_acquisition = acquisition_function(candidates)
-    return candidates, batch_acquisition
+        batch_acquisition = acquisition_function(clamped_candidates)
+
+    return clamped_candidates, batch_acquisition
 
 
 def get_best_candidates(batch_candidates: Tensor, batch_values: Tensor) -> Tensor:
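
The reduced-domain round trip above leans on FixedFeatureAcquisitionFunction._construct_X_full to re-insert the fixed columns into the optimized candidates. A simplified stand-in for the idea (construct_X_full_sketch is a hypothetical helper, not the BoTorch implementation):

import torch

def construct_X_full_sketch(X_reduced, columns, values, d):
    # Re-insert fixed columns into candidates optimized over the reduced domain.
    # X_reduced: `... x q x (d - len(columns))` tensor over unfixed features.
    # columns: sorted indices of fixed features; values: one value per column.
    unfixed = sorted(set(range(d)) - set(columns))
    X_full = X_reduced.new_zeros(*X_reduced.shape[:-1], d)
    X_full[..., unfixed] = X_reduced
    for col, val in zip(columns, values):
        X_full[..., col] = val
    return X_full

# Example: d=4 with features 1 and 3 fixed to 0.5 and -1.0.
X_reduced = torch.rand(2, 1, 2)
X_full = construct_X_full_sketch(X_reduced, columns=[1, 3], values=[0.5, -1.0], d=4)
assert X_full.shape == (2, 1, 4)
assert (X_full[..., 1] == 0.5).all() and (X_full[..., 3] == -1.0).all()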

botorch/generation/utils.py

Lines changed: 105 additions & 1 deletion
@@ -6,9 +6,12 @@
 
 from __future__ import annotations
 
-import typing  # noqa F401
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+from botorch.acquisition import AcquisitionFunction, FixedFeatureAcquisitionFunction
+from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints
 from torch import Tensor
 
 
@@ -44,3 +47,104 @@ def _flip_sub_unique(x: Tensor, k: int) -> Tensor:
         if len(out) >= k:
             break
     return x[idcs[: len(out)]]
+
+
+@dataclass(frozen=True, repr=False, eq=False)
+class _NoFixedFeatures:
+    """
+    Dataclass to store the objects after removing fixed features.
+    Objects here refer to the acquisition function, initial conditions,
+    bounds, and parameter constraints.
+    """
+
+    acquisition_function: FixedFeatureAcquisitionFunction
+    initial_conditions: Tensor
+    lower_bounds: Optional[Union[float, Tensor]]
+    upper_bounds: Optional[Union[float, Tensor]]
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]]
+    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]]
+
+
+def _remove_fixed_features_from_optimization(
+    fixed_features: Dict[int, Optional[float]],
+    acquisition_function: AcquisitionFunction,
+    initial_conditions: Tensor,
+    lower_bounds: Optional[Union[float, Tensor]],
+    upper_bounds: Optional[Union[float, Tensor]],
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
+    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
+) -> _NoFixedFeatures:
+    """
+    Given a set of non-empty fixed features, this function effectively reduces
+    the dimensionality of the domain that the acquisition function is being
+    optimized over by removing the set of fixed features. Consequently, this
+    function returns a new `FixedFeatureAcquisitionFunction`, new constraints,
+    and bounds defined over unfixed features.
+
+    Args:
+        fixed_features: A dictionary of feature indices to values, where
+            all generated candidates will have features fixed to these values.
+            If the dictionary value is None, then that feature will just be
+            fixed to the clamped value and not optimized. Assumes values to be
+            compatible with lower_bounds and upper_bounds!
+        acquisition_function: Acquisition function over the original domain
+            being maximized.
+        initial_conditions: Starting points for optimization w.r.t. the
+            complete domain.
+        lower_bounds: Minimum values for each column of initial_conditions.
+        upper_bounds: Maximum values for each column of initial_conditions.
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        _NoFixedFeatures dataclass object.
+    """
+    # sort the keys for consistency
+    sorted_keys = sorted(fixed_features)
+    sorted_values = []
+    for key in sorted_keys:
+        if fixed_features[key] is None:
+            val = initial_conditions[..., [key]]
+        else:
+            val = fixed_features[key]
+        sorted_values.append(val)
+
+    d = initial_conditions.shape[-1]
+    acquisition_function = FixedFeatureAcquisitionFunction(
+        acq_function=acquisition_function,
+        d=d,
+        columns=sorted_keys,
+        values=sorted_values,
+    )
+
+    # extract initial_conditions, bounds at unfixed indices
+    unfixed_indices = sorted(set(range(d)) - set(sorted_keys))
+    initial_conditions = initial_conditions[..., unfixed_indices]
+    if isinstance(lower_bounds, Tensor):
+        lower_bounds = lower_bounds[..., unfixed_indices]
+    if isinstance(upper_bounds, Tensor):
+        upper_bounds = upper_bounds[..., unfixed_indices]
+
+    inequality_constraints = _generate_unfixed_lin_constraints(
+        constraints=inequality_constraints,
+        fixed_features=fixed_features,
+        dimension=d,
+        eq=False,
+    )
+    equality_constraints = _generate_unfixed_lin_constraints(
+        constraints=equality_constraints,
+        fixed_features=fixed_features,
+        dimension=d,
+        eq=True,
+    )
+    return _NoFixedFeatures(
+        acquisition_function=acquisition_function,
+        initial_conditions=initial_conditions,
+        lower_bounds=lower_bounds,
+        upper_bounds=upper_bounds,
+        inequality_constraints=inequality_constraints,
+        equality_constraints=equality_constraints,
+    )
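
One detail worth highlighting in _remove_fixed_features_from_optimization: a None value in fixed_features pins that feature to its per-candidate value in initial_conditions rather than to a constant. A small sketch of that key-sorting loop in isolation (the values are illustrative):

import torch

# batch of one starting point in a 4-d domain (illustrative values)
initial_conditions = torch.tensor([[[0.1, 0.2, 0.3, 0.4]]])
fixed_features = {3: 0.9, 1: None}  # feature 1 is pinned to its initial value

sorted_keys = sorted(fixed_features)  # [1, 3]
sorted_values = []
for key in sorted_keys:
    if fixed_features[key] is None:
        # a tensor slice is kept so batched initial conditions broadcast
        sorted_values.append(initial_conditions[..., [key]])
    else:
        sorted_values.append(fixed_features[key])

# sorted_values is now [tensor([[[0.2000]]]), 0.9]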

botorch/optim/parameter_constraints.py

Lines changed: 64 additions & 1 deletion
@@ -15,7 +15,7 @@
 
 import numpy as np
 import torch
-from botorch.exceptions.errors import UnsupportedError
+from botorch.exceptions.errors import UnsupportedError, CandidateGenerationError
 from scipy.optimize import Bounds
 from torch import Tensor
 
@@ -267,3 +267,66 @@ def _make_linear_constraints(
     else:
         raise ValueError("`indices` must be at least one-dimensional")
     return constraints
+
+
+def _generate_unfixed_lin_constraints(
+    constraints: Optional[List[Tuple[Tensor, Tensor, float]]],
+    fixed_features: Dict[int, float],
+    dimension: int,
+    eq: bool,
+) -> Optional[List[Tuple[Tensor, Tensor, float]]]:
+
+    # If constraints is None or an empty list, then return it as-is
+    if not constraints:
+        return constraints
+
+    # replace_index generates the new indices for the unfixed dimensions
+    # after eliminating the fixed dimensions.
+    # Example: dimension = 5, ff.keys() = [1, 3], replace_index = {0: 0, 2: 1, 4: 2}
+    unfixed_keys = sorted(set(range(dimension)) - set(fixed_features))
+    unfixed_keys = torch.tensor(unfixed_keys).to(constraints[0][0])
+    replace_index = torch.arange(dimension - len(fixed_features)).to(constraints[0][0])
+
+    new_constraints = []
+    # parse constraints one-by-one
+    for constraint_id, (indices, coefficients, rhs) in enumerate(constraints):
+        new_rhs = rhs
+        new_indices = []
+        new_coefficients = []
+        # the following unsqueeze is done to facilitate a simpler for-loop.
+        indices_2dim = indices if indices.ndim == 2 else indices.unsqueeze(-1)
+        for coefficient, index in zip(coefficients, indices_2dim):
+            ffval_or_None = fixed_features.get(index[-1].item())
+            # if ffval_or_None is None, then the index is not fixed
+            if ffval_or_None is None:
+                new_indices.append(index)
+                new_coefficients.append(coefficient)
+            # otherwise, we "remove" the constraint term corresponding to that
+            # index by folding its contribution into the right-hand side
+            else:
+                new_rhs -= coefficient.item() * ffval_or_None
+
+        # all indices were fixed, so the constraint is gone.
+        if len(new_indices) == 0:
+            if (eq and new_rhs != 0) or (not eq and new_rhs > 0):
+                prefix = "Eq" if eq else "Ineq"
+                raise CandidateGenerationError(
+                    f"{prefix}uality constraint {constraint_id} not met "
+                    "with fixed_features."
+                )
+        else:
+            # However, one key transformation has to be noted.
+            # new_indices is with respect to the older (fuller) domain, and so
+            # it will have to be converted using replace_index.
+            new_indices = torch.stack(new_indices, dim=0)
+            # generate new index location after the removal of fixed_features indices
+            new_indices_dim_d = new_indices[:, -1].unsqueeze(-1)
+            new_indices_dim_d = replace_index[
+                torch.nonzero(new_indices_dim_d == unfixed_keys, as_tuple=True)[1]
+            ]
+            new_indices[:, -1] = new_indices_dim_d
+            # squeeze(-1) is a no-op if dim -1 is not singleton
+            new_indices.squeeze_(-1)
+            # convert new_coefficients to Tensor
+            new_coefficients = torch.stack(new_coefficients)
+            new_constraints.append((new_indices, new_coefficients, new_rhs))
+    return new_constraints
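
To make the index remapping concrete: with dimension=5 and features 1 and 3 fixed, the unfixed indices [0, 2, 4] are renumbered to [0, 1, 2], and each fixed term folds into the right-hand side. A hypothetical call to the function added above (constraint values are illustrative):

import torch
from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints

# Original constraint over the full 5-d domain:
#   1.0 * X[2] + 2.0 * X[3] >= 3.0
constraints = [(torch.tensor([2, 3]), torch.tensor([1.0, 2.0]), 3.0)]

new = _generate_unfixed_lin_constraints(
    constraints=constraints,
    fixed_features={1: 1.0, 3: 0.0},
    dimension=5,
    eq=False,
)
# X[3] is fixed to 0.0, so its term moves to the rhs (3.0 - 2.0 * 0.0 = 3.0),
# and X[2] is renumbered to index 1 in the reduced 3-d domain:
#   new == [(tensor([1]), tensor([1.]), 3.0)]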
