Commit 86cac04

Utilities for estimating feasible volume (#437)
Summary:
Pull Request resolved: #437

Estimates the feasible volume via a Monte Carlo method. First, the code calculates the proportion of features uniformly sampled from a box that also satisfy the linear constraints. Second, the code calculates the proportion of these feasible features for which posterior samples satisfy the outcome constraints with probability above a given lower bound.

Reviewed By: Balandat

Differential Revision: D21355820

fbshipit-source-id: be603815e524d9840d675f8e23cd24e3dfa7a720
1 parent 62578b4 commit 86cac04
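For orientation, the diff below uses two constraint encodings: linear feature constraints are `(indices, coefficients, rhs)` tuples meaning `\sum_i X[indices[i]] * coefficients[i] >= rhs`, and outcome constraints are callables on posterior samples that are non-positive exactly when the sampled outcome is feasible. A minimal sketch of both encodings (the concrete numbers are illustrative only, not part of the commit):

import torch

# Hypothetical linear feature constraint -X[0] + X[1] >= 1,
# encoded as (indices, coefficients, rhs).
inequality_constraints = [(torch.tensor([0, 1]), torch.tensor([-1.0, 1.0]), 1.0)]

# Hypothetical outcome constraint y[..., 0] <= 0.5, encoded as a callable
# that is <= 0 when the sampled outcome is feasible.
outcome_constraints = [lambda y: y[..., 0] - 0.5]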

5 files changed: +282 -2 lines changed

botorch/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from .constraints import get_outcome_constraint_transforms
+from .feasible_volume import estimate_feasible_volume
 from .objective import apply_constraints, get_objective_weights_transform
 from .sampling import (
     batched_multinomial,
@@ -30,4 +31,5 @@
     "squeeze_last_dim",
     "standardize",
     "t_batch_mode_transform",
+    "estimate_feasible_volume",
 ]

botorch/utils/feasible_volume.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+from typing import Callable, List, Optional, Tuple
+
+import botorch.models.model as model
+import torch
+from torch import Tensor
+
+from ..logging import _get_logger
+from .sampling import manual_seed
+
+
+logger = _get_logger(name="Feasibility")
+
+
+def get_feasible_samples(
+    samples: Tensor,
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+) -> Tuple[Tensor, float]:
+    r"""
+    Checks which of the samples satisfy all of the inequality constraints.
+
+    Args:
+        samples: A `sample_size x d` tensor of feature samples,
+            where d is the feature dimension.
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+    Returns:
+        2-element tuple containing
+
+        - Samples satisfying the linear constraints.
+        - Estimated proportion of samples satisfying the linear constraints.
+    """
+
+    if inequality_constraints is None:
+        return samples, 1.0
+
+    nsamples = samples.size(0)
+
+    feasible = torch.ones(nsamples, device=samples.device, dtype=torch.bool)
+
+    for (indices, coefficients, rhs) in inequality_constraints:
+        feasible &= samples.index_select(1, indices) @ coefficients >= rhs
+
+    feasible_samples = samples[feasible]
+
+    p_linear = feasible_samples.size(0) / nsamples
+
+    return feasible_samples, p_linear
+
+
+def get_outcome_feasibility_probability(
+    model: model.Model,
+    X: Tensor,
+    outcome_constraints: List[Callable[[Tensor], Tensor]],
+    threshold: float = 0.1,
+    nsample_outcome: int = 1000,
+    seed: Optional[int] = None,
+) -> float:
+    r"""
+    Monte Carlo estimate of the feasible volume with respect to the outcome constraints.
+
+    Args:
+        model: The model used for sampling the posterior.
+        X: A tensor of dimension `batch-shape x 1 x d`, where d is the feature dimension.
+        outcome_constraints: A list of callables, each mapping a Tensor of dimension
+            `sample_shape x batch-shape x q x m` to a Tensor of dimension
+            `sample_shape x batch-shape x q`, where negative values imply feasibility.
+        threshold: A lower limit on the probability that posterior samples are feasible.
+        nsample_outcome: The number of samples from the model posterior.
+        seed: The seed for the posterior sampler. If omitted, use a random seed.
+
+    Returns:
+        Estimated proportion of features for which posterior samples satisfy
+        given outcome constraints with probability above or equal to
+        the given threshold.
+    """
+    from botorch.sampling import SobolQMCNormalSampler
+
+    seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
+
+    posterior = model.posterior(X)  # posterior consists of batch_shape marginals
+    sampler = SobolQMCNormalSampler(num_samples=nsample_outcome, seed=seed)
+    # size of samples: (num outcome samples, batch_shape, 1, outcome dim)
+    samples = sampler(posterior)
+
+    feasible = torch.ones(samples.shape[:-1], dtype=torch.bool, device=samples.device)
+
+    # a sample passes if each constraint applied to the sample
+    # produces a non-positive value
+    for oc in outcome_constraints:
+        # broadcasted evaluation of the outcome constraints
+        feasible &= oc(samples) <= 0
+
+    # proportion of feasible samples for each of the elements of X
+    # summation is done across feasible outcome samples
+    p_feas = feasible.sum(0).float() / feasible.size(0)
+
+    # proportion of features leading to the posterior outcome
+    # satisfying the given outcome constraints
+    # with probability above a given threshold
+    p_outcome = (p_feas >= threshold).sum().item() / X.size(0)
+
+    return p_outcome
+
+
+def estimate_feasible_volume(
+    bounds: Tensor,
+    model: model.Model,
+    outcome_constraints: List[Callable[[Tensor], Tensor]],
+    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
+    nsample_feature: int = 1000,
+    nsample_outcome: int = 1000,
+    threshold: float = 0.1,
+    verbose: bool = False,
+    seed: Optional[int] = None,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> Tuple[float, float]:
+    r"""
+    Monte Carlo estimate of the feasible volume with respect
+    to feature constraints and outcome constraints.
+
+    Args:
+        bounds: A `2 x d` tensor of lower and upper bounds
+            for each column of `X`.
+        model: The model used for sampling the outcomes.
+        outcome_constraints: A list of callables, each mapping a Tensor of dimension
+            `sample_shape x batch-shape x q x m` to a Tensor of dimension
+            `sample_shape x batch-shape x q`, where negative values imply
+            feasibility.
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        nsample_feature: The number of feature samples satisfying the bounds.
+        nsample_outcome: The number of outcome samples from the model posterior.
+        threshold: A lower limit for the probability of outcome feasibility.
+        seed: The seed for both feature and outcome samplers. If omitted,
+            use a random seed.
+        verbose: An indicator for whether to log the results.
+
+    Returns:
+        2-element tuple containing:
+
+        - Estimated proportion of volume in feature space that is
+            feasible wrt the bounds and the inequality constraints (linear).
+        - Estimated proportion of feasible features for which
+            posterior samples (outcome) satisfy the outcome constraints
+            with probability above the given threshold.
+    """
+
+    seed = seed if seed is not None else torch.randint(0, 1000000, (1,)).item()
+
+    with manual_seed(seed=seed):
+        box_samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(
+            (nsample_feature, bounds.size(1)), dtype=dtype, device=device
+        )
+
+    features, p_feature = get_feasible_samples(
+        samples=box_samples, inequality_constraints=inequality_constraints
+    )  # each new feature sample is a row
+
+    p_outcome = get_outcome_feasibility_probability(
+        model=model,
+        X=features.unsqueeze(-2),
+        outcome_constraints=outcome_constraints,
+        threshold=threshold,
+        nsample_outcome=nsample_outcome,
+        seed=seed,
+    )
+
+    if verbose:  # pragma: no cover
+        logger.info(
+            "Proportion of volume that satisfies linear constraints: "
+            + f"{p_feature:.4e}"
+        )
+        if p_feature <= 0.01:
+            logger.warning(
+                "The proportion of satisfying volume is very low and may lead to "
+                + "very long run times. Consider making your constraints less "
+                + "restrictive."
+            )
+        logger.info(
+            "Proportion of linear-feasible volume that also satisfies each "
+            + f"outcome constraint with probability > {threshold}: {p_outcome:.4e}"
+        )
+        if p_outcome <= 0.001:
+            logger.warning(
+                "The proportion of volume that also satisfies the outcome constraint "
+                + "is very low. Consider making your parameter and outcome constraints "
+                + "less restrictive."
+            )
+    return p_feature, p_outcome
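For context, here is a minimal usage sketch of the new estimate_feasible_volume utility with a fitted single-output GP. The data, model, and constraint values are hypothetical, and the sketch assumes the BoTorch API as of this commit (e.g. fit_gpytorch_model); it is not part of the diff.

import torch
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.utils.feasible_volume import estimate_feasible_volume
from gpytorch.mlls import ExactMarginalLogLikelihood

# toy training data on the unit square (hypothetical)
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
# linear feature constraint -X[0] + X[1] >= 0, i.e. X[1] >= X[0]
inequality_constraints = [
    (torch.tensor([0, 1]), torch.tensor([-1.0, 1.0], dtype=torch.double), 0.0)
]
# outcome constraint: feasible when the single outcome is at most 1.0
outcome_constraints = [lambda y: y[..., 0] - 1.0]

p_feature, p_outcome = estimate_feasible_volume(
    bounds=bounds,
    model=model,
    outcome_constraints=outcome_constraints,
    inequality_constraints=inequality_constraints,
    nsample_feature=512,
    nsample_outcome=128,
    threshold=0.75,
    dtype=torch.double,
    seed=0,
)
print(f"proportion of box feasible under linear constraints: {p_feature:.3f}")
print(f"feasible features meeting the outcome constraint w.p. >= 0.75: {p_outcome:.3f}")

Here p_feature estimates the first proportion described in the commit summary and p_outcome the second.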

botorch/utils/testing.py

Lines changed: 2 additions & 2 deletions
@@ -13,14 +13,14 @@
 from unittest import TestCase

 import torch
-from botorch.posteriors.gpytorch import GPyTorchPosterior
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.lazy import AddedDiagLazyTensor, DiagLazyTensor
 from torch import Tensor

 from .. import settings
 from ..models.model import Model
-from ..posteriors import Posterior
+from ..posteriors.gpytorch import GPyTorchPosterior
+from ..posteriors.posterior import Posterior
 from ..test_functions.base import BaseTestProblem

sphinx/source/utils.rst

Lines changed: 5 additions & 0 deletions
@@ -31,3 +31,8 @@ Transformations
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: botorch.utils.transforms
     :members:
+
+Feasible Volume
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: botorch.utils.feasible_volume
+    :members:
test/utils/test_feasible_volume.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+#! /usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from botorch.utils.feasible_volume import (
+    estimate_feasible_volume,
+    get_feasible_samples,
+    get_outcome_feasibility_probability,
+)
+from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
+
+
+class TestFeasibleVolumeEstimates(BotorchTestCase):
+    def test_feasible_samples(self):
+        # -X[0] + X[1] >= 1
+        inequality_constraints = [(torch.tensor([0, 1]), torch.tensor([-1.0, 1.0]), 1)]
+        box_samples = torch.tensor([[1.1, 2.0], [0.9, 2.1], [1.5, 2], [1.8, 2.2]])
+
+        feasible_samples, p_linear = get_feasible_samples(
+            samples=box_samples, inequality_constraints=inequality_constraints
+        )
+
+        feasible = box_samples[:, 1] - box_samples[:, 0] >= 1
+
+        self.assertTrue(
+            torch.all(torch.eq(feasible_samples, box_samples[feasible])).item()
+        )
+        self.assertEqual(p_linear, feasible.sum(0).float().item() / feasible.size(0))
+
+    def test_outcome_feasibility_probability(self):
+        for dtype in (torch.float, torch.double):
+            samples = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
+            mm = MockModel(MockPosterior(samples=samples))
+            X = torch.zeros(1, 1, device=self.device, dtype=torch.double)
+
+            for outcome_constraints in [
+                [lambda y: y[..., 0] - 0.5],
+                [lambda y: y[..., 0] + 1.0],
+            ]:
+                p_outcome = get_outcome_feasibility_probability(
+                    model=mm,
+                    X=X,
+                    outcome_constraints=outcome_constraints,
+                    nsample_outcome=2,
+                )
+                feasible = outcome_constraints[0](samples) <= 0
+                self.assertEqual(p_outcome, feasible)
+
+    def test_estimate_feasible_volume(self):
+        for dtype in (torch.float, torch.double):
+            for samples in (
+                torch.zeros(1, 2, 1, device=self.device, dtype=dtype),
+                torch.ones(1, 1, 1, device=self.device, dtype=dtype),
+            ):
+
+                mm = MockModel(MockPosterior(samples=samples))
+                bounds = torch.ones((2, 1))
+                outcome_constraints = [lambda y: y[..., 0] - 0.5]
+
+                p_linear, p_outcome = estimate_feasible_volume(
+                    bounds=bounds,
+                    model=mm,
+                    outcome_constraints=outcome_constraints,
+                    nsample_feature=2,
+                    nsample_outcome=1,
+                )
+
+                self.assertEqual(p_linear, 1.0)
+                self.assertEqual(p_outcome, 1.0 - samples[0, 0].item())
