Skip to content

Commit b717268

Browse files
sdaultonfacebook-github-bot
authored andcommitted
handle constraints in qSimpleRegret (#2141)
Summary: Pull Request resolved: #2141 This fixes two issues: 1. `constraints` could be passed to `qSimpleRegret`, but would error out if the objective was negative. 2. `constraints` were not used in the input constructor for `qSimpleRegret` in any fashion. This now enforces that constraints cannot be passed to `qSimpleRegret`, and instead, they must be passed via a `ConstrainedMCObjective`. This diff also constructs the appropriate `ConstrainedMCObjective` in the input constructor. Reviewed By: Balandat Differential Revision: D51964703 fbshipit-source-id: 43cf60f50457a2bc89c9454a0add9c917e585ce4
1 parent 449b911 commit b717268

File tree

4 files changed

+124
-1
lines changed

4 files changed

+124
-1
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
)
8282
from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
8383
from botorch.acquisition.objective import (
84+
ConstrainedMCObjective,
8485
IdentityMCObjective,
8586
MCAcquisitionObjective,
8687
PosteriorTransform,
@@ -90,6 +91,7 @@
9091
from botorch.acquisition.utils import (
9192
compute_best_feasible_objective,
9293
expand_trace_observations,
94+
get_infeasible_cost,
9395
get_optimal_samples,
9496
project_to_target_fidelity,
9597
)
@@ -433,6 +435,8 @@ def construct_inputs_qSimpleRegret(
433435
posterior_transform: Optional[PosteriorTransform] = None,
434436
X_pending: Optional[Tensor] = None,
435437
sampler: Optional[MCSampler] = None,
438+
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
439+
X_baseline: Optional[Tensor] = None,
436440
) -> Dict[str, Any]:
437441
r"""Construct kwargs for qSimpleRegret.
438442
@@ -446,10 +450,28 @@ def construct_inputs_qSimpleRegret(
446450
but have not yet been evaluated.
447451
sampler: The sampler used to draw base samples. If omitted, uses
448452
the acquisition functions's default sampler.
453+
constraints: A list of constraint callables which map a Tensor of posterior
454+
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
455+
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
456+
are considered satisfied if the output is less than zero.
457+
X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
458+
that have already been observed. These points are considered as
459+
the potential best design point. If omitted, checks that all
460+
training_data have the same input features and take the first `X`.
449461
450462
Returns:
451463
A dict mapping kwarg names of the constructor to values.
452464
"""
465+
if constraints is not None:
466+
if X_baseline is None:
467+
raise ValueError("Constraints require an X_baseline.")
468+
objective = ConstrainedMCObjective(
469+
objective=objective,
470+
constraints=constraints,
471+
infeasible_cost=get_infeasible_cost(
472+
X=X_baseline, model=model, objective=objective
473+
),
474+
)
453475
return {
454476
"model": model,
455477
"objective": objective,

botorch/acquisition/monte_carlo.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,13 +729,54 @@ class qSimpleRegret(SampleReducingMCAcquisitionFunction):
729729
730730
`qSR(X) = E(max Y), Y ~ f(X), X = (x_1,...,x_q)`
731731
732+
Constraints should be provided as a `ConstrainedMCObjective`.
733+
Passing `constraints` as an argument is not supported. This is because
734+
`SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
735+
level and then weights the sample-level acquisition values by a soft feasibility
736+
indicator. Hence, it expects non-log acquisition function values to be
737+
non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
738+
a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
739+
computing the acquisition function) and shifts negative objective values using
740+
by an infeasible cost to ensure non-negativity (before applying constraints and
741+
shifting them back).
742+
732743
Example:
733744
>>> model = SingleTaskGP(train_X, train_Y)
734745
>>> sampler = SobolQMCNormalSampler(1024)
735746
>>> qSR = qSimpleRegret(model, sampler)
736747
>>> qsr = qSR(test_X)
737748
"""
738749

750+
def __init__(
751+
self,
752+
model: Model,
753+
sampler: Optional[MCSampler] = None,
754+
objective: Optional[MCAcquisitionObjective] = None,
755+
posterior_transform: Optional[PosteriorTransform] = None,
756+
X_pending: Optional[Tensor] = None,
757+
) -> None:
758+
r"""q-Simple Regret.
759+
760+
Args:
761+
model: A fitted model.
762+
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
763+
more details.
764+
objective: The MCAcquisitionObjective under which the samples are
765+
evaluated. Defaults to `IdentityMCObjective()`.
766+
posterior_transform: A PosteriorTransform (optional).
767+
X_pending: A `m x d`-dim Tensor of `m` design points that have
768+
points that have been submitted for function evaluation
769+
but have not yet been evaluated. Concatenated into X upon
770+
forward call. Copied and set to have no gradient.
771+
"""
772+
super().__init__(
773+
model=model,
774+
sampler=sampler,
775+
objective=objective,
776+
posterior_transform=posterior_transform,
777+
X_pending=X_pending,
778+
)
779+
739780
def _sample_forward(self, obj: Tensor) -> Tensor:
740781
r"""Evaluate qSimpleRegret per sample on the candidate set `X`.
741782
@@ -757,6 +798,17 @@ class qUpperConfidenceBound(SampleReducingMCAcquisitionFunction):
757798
`qUCB = E(max(mu + |Y_tilde - mu|))`, where `Y_tilde ~ N(mu, beta pi/2 Sigma)`
758799
and `f(X)` has distribution `N(mu, Sigma)`.
759800
801+
Constraints should be provided as a `ConstrainedMCObjective`.
802+
Passing `constraints` as an argument is not supported. This is because
803+
`SampleReducingMCAcquisitionFunction` computes the acquisition values on the sample
804+
level and then weights the sample-level acquisition values by a soft feasibility
805+
indicator. Hence, it expects non-log acquisition function values to be
806+
non-negative. `qSimpleRegret` acquisition values can be negative, so we instead use
807+
a `ConstrainedMCObjective` which applies constraints to the objectives (e.g. before
808+
computing the acquisition function) and shifts negative objective values using
809+
by an infeasible cost to ensure non-negativity (before applying constraints and
810+
shifting them back).
811+
760812
Example:
761813
>>> model = SingleTaskGP(train_X, train_Y)
762814
>>> sampler = SobolQMCNormalSampler(1024)

test/acquisition/test_input_constructors.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
)
7575
from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
7676
from botorch.acquisition.objective import (
77+
ConstrainedMCObjective,
7778
LinearMCObjective,
7879
ScalarizedPosteriorTransform,
7980
)
@@ -473,6 +474,35 @@ def test_construct_inputs_mc_base(self) -> None:
473474
self.assertIsNone(kwargs["sampler"])
474475
acqf = qSimpleRegret(**kwargs)
475476
self.assertIs(acqf.model, mock_model)
477+
# test constraints
478+
constraints = [lambda Y: Y[..., 0]]
479+
with self.assertRaisesRegex(ValueError, "Constraints require an X_baseline."):
480+
c(
481+
model=mock_model,
482+
training_data=self.blockX_blockY,
483+
objective=objective,
484+
X_pending=X_pending,
485+
constraints=constraints,
486+
)
487+
with mock.patch(
488+
"botorch.acquisition.input_constructors.get_infeasible_cost",
489+
return_value=2.0,
490+
):
491+
kwargs = c(
492+
model=mock_model,
493+
training_data=self.blockX_blockY,
494+
objective=objective,
495+
X_pending=X_pending,
496+
constraints=constraints,
497+
X_baseline=X_pending,
498+
)
499+
acqf = qSimpleRegret(**kwargs)
500+
self.assertIsNone(acqf._constraints)
501+
self.assertIsInstance(acqf.objective, ConstrainedMCObjective)
502+
self.assertIs(acqf.objective.objective, objective)
503+
self.assertIs(acqf.objective.constraints, constraints)
504+
self.assertEqual(acqf.objective.infeasible_cost.item(), 2.0)
505+
476506
# TODO: Test passing through of sampler
477507

478508
def test_construct_inputs_qEI(self) -> None:

test/acquisition/test_monte_carlo.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ def _sample_forward(self, X):
4949
pass
5050

5151

52+
class NegativeReducingMCAcquisitionFunction(SampleReducingMCAcquisitionFunction):
53+
def _sample_forward(self, X):
54+
return torch.full_like(X, -1.0)
55+
56+
5257
def infeasible_con(samples: Tensor) -> Tensor:
5358
return torch.ones_like(samples[..., 0])
5459

@@ -806,6 +811,18 @@ def test_q_simple_regret_batch(self):
806811
acqf(X)
807812
self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
808813

814+
def test_q_simple_regret_constraints(self):
815+
# basic test that passing constraints directly is not allowed
816+
samples = torch.zeros(2, 2, 1, device=self.device, dtype=torch.double)
817+
samples[0, 0, 0] = 1.0
818+
mm = MockModel(MockPosterior(samples=samples))
819+
regex = (
820+
r"qSimpleRegret\.__init__\(\) got an unexpected keyword argument "
821+
r"'constraints'"
822+
)
823+
with self.assertRaisesRegex(TypeError, regex):
824+
qSimpleRegret(model=mm, constraints=[lambda Y: Y[..., 0]])
825+
809826
# TODO: Test different objectives (incl. constraints)
810827

811828

@@ -988,7 +1005,9 @@ def test_mc_acquisition_function_with_constraints(self):
9881005
# regret because the acquisition utility is negative.
9891006
samples = -torch.rand(n, q, m, device=self.device, dtype=dtype)
9901007
mm = MockModel(MockPosterior(samples=samples))
991-
cacqf = qSimpleRegret(model=mm, constraints=[feasible_con])
1008+
cacqf = NegativeReducingMCAcquisitionFunction(
1009+
model=mm, constraints=[feasible_con]
1010+
)
9921011
with self.assertRaisesRegex(
9931012
ValueError,
9941013
"Constraint-weighting requires unconstrained "

0 commit comments

Comments
 (0)