Skip to content

Commit 9a45872

Browse files
sdaultonfacebook-github-bot
authored andcommitted
project point to feasible space via quadratic programming (#3010)
Summary: Pull Request resolved: #3010 see title. This is particularly useful for resolving numerical issues with Ax when it checks parameter constraints. Reviewed By: Balandat Differential Revision: D82328877 fbshipit-source-id: 3aa7d069a005f0ab95bcc91f98fe5309f38ebb53
1 parent eba2dce commit 9a45872

File tree

4 files changed

+830
-13
lines changed

4 files changed

+830
-13
lines changed

botorch/optim/optimize.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
gen_one_shot_kg_initial_conditions,
3636
TGenInitialConditions,
3737
)
38-
from botorch.optim.parameter_constraints import evaluate_feasibility
38+
from botorch.optim.parameter_constraints import (
39+
evaluate_feasibility,
40+
project_to_feasible_space_via_slsqp,
41+
)
3942
from botorch.optim.stopping import ExpMAStoppingCriterion
4043
from torch import Tensor
4144

@@ -513,15 +516,47 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
513516

514517
# SLSQP can sometimes fail to produce a feasible candidate. Check for
515518
# feasibility and error out if necessary.
519+
# if there are equality constraints, project the candidate to the feasible set
520+
equality_constraints = gen_kwargs.get("equality_constraints")
521+
inequality_constraints = gen_kwargs.get("inequality_constraints")
522+
nonlinear_inequality_constraints = gen_kwargs.get(
523+
"nonlinear_inequality_constraints"
524+
)
516525
is_feasible = evaluate_feasibility(
517526
X=batch_candidates,
518-
inequality_constraints=gen_kwargs.get("inequality_constraints"),
519-
equality_constraints=gen_kwargs.get("equality_constraints"),
520-
nonlinear_inequality_constraints=gen_kwargs.get(
521-
"nonlinear_inequality_constraints"
522-
),
527+
inequality_constraints=inequality_constraints,
528+
equality_constraints=equality_constraints,
529+
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
523530
)
524531
infeasible = ~is_feasible
532+
if nonlinear_inequality_constraints is None and infeasible.any():
533+
projected_candidates = project_to_feasible_space_via_slsqp(
534+
X=batch_candidates[infeasible],
535+
bounds=opt_inputs.bounds,
536+
equality_constraints=equality_constraints,
537+
inequality_constraints=inequality_constraints,
538+
)
539+
if opt_inputs.post_processing_func is not None:
540+
projected_candidates = opt_inputs.post_processing_func(projected_candidates)
541+
batch_candidates[infeasible] = projected_candidates
542+
# recompute AF values for projected points
543+
with torch.no_grad():
544+
batch_acq_values[infeasible] = torch.cat(
545+
[
546+
opt_inputs.acq_function(cand)
547+
for cand in projected_candidates.split(batch_limit, dim=0)
548+
],
549+
dim=0,
550+
)
551+
# re-evaluate feasibility
552+
is_feasible = evaluate_feasibility(
553+
X=batch_candidates,
554+
inequality_constraints=inequality_constraints,
555+
equality_constraints=equality_constraints,
556+
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
557+
)
558+
infeasible = ~is_feasible
559+
525560
if (opt_inputs.return_best_only and (not is_feasible.any())) or infeasible.all():
526561
raise CandidateGenerationError(
527562
f"The optimizer produced infeasible candidates. "

botorch/optim/parameter_constraints.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,32 @@
1818
import numpy.typing as npt
1919
import torch
2020
from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
21-
from scipy.optimize import Bounds
21+
from botorch.optim.utils import columnwise_clamp
22+
from scipy.optimize import Bounds, minimize
2223
from torch import Tensor
2324

2425

2526
ScipyConstraintDict = dict[
2627
str, Union[str, Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]
2728
]
28-
CONST_TOL = 1e-6
29+
30+
31+
def get_constraint_tolerance(dtype: torch.dtype) -> float:
32+
r"""Get the constraint tolerance for a given dtype.
33+
34+
Args:
35+
dtype: The dtype to use.
36+
37+
Returns:
38+
The constraint tolerance for the given dtype.
39+
"""
40+
if dtype == torch.double:
41+
return 1e-8
42+
elif dtype == torch.float:
43+
return 1e-6
44+
elif dtype == torch.half:
45+
return 1e-4
46+
raise ValueError(f"Unsupported dtype {dtype}.")
2947

3048

3149
def make_scipy_bounds(
@@ -513,7 +531,7 @@ def nonlinear_constraint_is_feasible(
513531
nonlinear_inequality_constraint: Callable,
514532
is_intrapoint: bool,
515533
x: Tensor,
516-
tolerance: float = CONST_TOL,
534+
tolerance: float | None = None,
517535
) -> Tensor:
518536
"""Checks if a nonlinear inequality constraint is fulfilled (within tolerance).
519537
@@ -533,6 +551,8 @@ def nonlinear_constraint_is_feasible(
533551
A boolean tensor of shape (batch) indicating if the constraint is
534552
satified by the corresponding batch of `x`.
535553
"""
554+
if tolerance is None:
555+
tolerance = get_constraint_tolerance(dtype=x.dtype)
536556

537557
def check_x(x: Tensor) -> bool:
538558
return _arrayify(nonlinear_inequality_constraint(x)).item() >= -tolerance
@@ -615,7 +635,7 @@ def evaluate_feasibility(
615635
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
616636
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
617637
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
618-
tolerance: float = CONST_TOL,
638+
tolerance: float | None = None,
619639
) -> Tensor:
620640
r"""Evaluate feasibility of candidate points (within a tolerance).
621641
@@ -657,6 +677,9 @@ def evaluate_feasibility(
657677
A boolean tensor of shape `batch` indicating if the corresponding candidate of
658678
shape `q x d` is feasible.
659679
"""
680+
if tolerance is None:
681+
tolerance = get_constraint_tolerance(dtype=X.dtype)
682+
660683
is_feasible = torch.ones(X.shape[:-2], device=X.device, dtype=torch.bool)
661684
if inequality_constraints is not None:
662685
for idx, coef, rhs in inequality_constraints:
@@ -691,3 +714,78 @@ def evaluate_feasibility(
691714
tolerance=tolerance,
692715
)
693716
return is_feasible
717+
718+
719+
def project_to_feasible_space_via_slsqp(
720+
X: Tensor,
721+
bounds: Tensor,
722+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
723+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
724+
) -> Tensor:
725+
"""Project X onto the feasible space by solving a quadratic program.
726+
727+
This uses SLSQP with gradients to solve the quadratic program.
728+
NOTE: A proper specialized QP solver would be a better choice here,
729+
but we'd like to avoid adding dependency on additional packages.
730+
SLSQP should be able to solve this reliably and quickly since the
731+
dimension is typically low and the number of constraints is typically
732+
limited.
733+
734+
Args:
735+
X: A `(batch_shape x) n x d`-dim tensor of inptus.
736+
bounds: A `2 x d`-dim tensor of lower and upper bounds.
737+
inequality_constraints: A list of tuples (indices, coefficients, rhs),
738+
with each tuple encoding an inequality constraint of the form
739+
`sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
740+
`coefficients` should be torch tensors. See the docstring of
741+
`make_scipy_linear_constraints` for an example.
742+
equality_constraints: A list of tuples (indices, coefficients, rhs).
743+
744+
Returns:
745+
A `(batch_shape x) n x d`-dim tensor of projected values.
746+
"""
747+
if inequality_constraints is None and equality_constraints is None:
748+
return X
749+
bounds_scipy = make_scipy_bounds(
750+
X=X, lower_bounds=bounds[0], upper_bounds=bounds[1]
751+
)
752+
constraints = make_scipy_linear_constraints(
753+
shapeX=X.shape,
754+
inequality_constraints=inequality_constraints,
755+
equality_constraints=equality_constraints,
756+
)
757+
# Define squared distance objective
758+
X_np = X.flatten().detach().cpu().numpy()
759+
760+
def objective(x: np.ndarray):
761+
return 0.5 * np.sum((x - X_np) ** 2)
762+
763+
def grad_objective(x: np.ndarray):
764+
return x - X_np
765+
766+
x0 = (
767+
columnwise_clamp(X=X, lower=bounds[0], upper=bounds[1], raise_on_violation=True)
768+
.detach()
769+
.cpu()
770+
.numpy()
771+
.flatten()
772+
)
773+
# NOTE: A proper specialized QP solver would be a better choice here,
774+
# but we'd like to avoid adding dependency on additional packages.
775+
# SLSQP should be able to solve this reliably and quickly since the
776+
# dimension is typically low and the number of constraints is typically
777+
# limited.
778+
result = minimize(
779+
fun=objective,
780+
x0=x0,
781+
method="SLSQP",
782+
jac=grad_objective,
783+
bounds=bounds_scipy,
784+
constraints=constraints,
785+
tol=get_constraint_tolerance(dtype=X.dtype),
786+
)
787+
788+
if not result.success:
789+
raise RuntimeError(f"Optimization failed: {result.message}")
790+
791+
return torch.from_numpy(result.x).to(X).view(X.shape)

0 commit comments

Comments
 (0)