Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions botorch/acquisition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from __future__ import annotations

import math
import warnings
from collections.abc import Callable

import torch
Expand All @@ -24,12 +25,16 @@
DeprecationError,
UnsupportedError,
)
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.fully_bayesian import MCMC_DIM
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.objective import compute_feasibility_indicator
from botorch.utils.objective import (
compute_feasibility_indicator,
compute_smoothed_feasibility_indicator,
)
from botorch.utils.sampling import optimize_posterior_samples
from botorch.utils.transforms import is_ensemble, normalize_indices
from gpytorch.models import GP
Expand Down Expand Up @@ -150,6 +155,13 @@ def compute_best_feasible_objective(
raise ValueError(
"Must specify `X_baseline` when no feasible observation exists."
)
warnings.warn(
"When all training points are infeasible, it is better to use "
"q(Log)ProbabilityOfFeasibility.",
BotorchWarning,
stacklevel=2,
)

infeasible_value = _estimate_objective_lower_bound(
model=model,
objective=objective,
Expand All @@ -171,8 +183,9 @@ def _estimate_objective_lower_bound(
posterior_transform: PosteriorTransform | None,
X: Tensor,
) -> Tensor:
"""Estimates a lower bound on the objective values by evaluating the model at convex
combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
"""Estimates a lower bound on the objective values by evaluating the at uniformly
random points in the bounding box of `X`, returning the 6-sigma lower bound of the
computed statistics.

Args:
model: A fitted model.
Expand All @@ -183,19 +196,21 @@ def _estimate_objective_lower_bound(
Returns:
A `m`-dimensional Tensor of lower bounds of the objectives.
"""
convex_weights = torch.rand(
32,
X.shape[-2],
dtype=X.dtype,
device=X.device,
# we do not have access to `bounds` here, so we infer the bounding box
# from data, expanding by 10% in each direction
X_lb = X.min(dim=-2, keepdim=True).values
X_ub = X.max(dim=-2, keepdim=True).values
X_range = X_ub - X_lb
X_padding = 0.1 * X_range
uniform_samples = torch.rand(
*X.shape[:-2], 32, X.shape[-1], dtype=X.dtype, device=X.device
)
weights_sum = convex_weights.sum(dim=0, keepdim=True)
convex_weights = convex_weights / weights_sum
X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
# infeasible cost M is such that -M < min_x f(x), thus
# 0 < min_x f(x) - (-M), so we should take -M as a lower
# bound on the best feasible objective
return -get_infeasible_cost(
X=convex_weights @ X,
X=X_samples,
model=model,
objective=objective,
posterior_transform=posterior_transform,
Expand Down Expand Up @@ -235,7 +250,19 @@ def objective(Y: Tensor, X: Tensor | None = None):
return Y.squeeze(-1)

posterior = model.posterior(X, posterior_transform=posterior_transform)
lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X)
# We check both the upper and lower bound of the posterior, since the objective
# may be increasing or decreasing. For objectives that are neither (eg. absolute
# distance from a target), this should still provide a good bound.
six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
lb = torch.stack(
[
objective(posterior.mean - six_stdv, X=X),
objective(posterior.mean + six_stdv, X=X),
],
dim=0,
)
lb = lb.min(dim=0).values

if lb.ndim < posterior.mean.ndim:
lb = lb.unsqueeze(-1)
# Take outcome-wise min. Looping in to handle batched models.
Expand Down Expand Up @@ -311,6 +338,7 @@ def _prune_inferior_shared_processing(
samples=samples,
marginalize_dim=marginalize_dim,
)

return max_points, obj_vals, infeas


Expand Down Expand Up @@ -374,7 +402,24 @@ def prune_inferior_points(
sampler=sampler,
marginalize_dim=marginalize_dim,
)
if infeas.any():
if infeas.all():
# if no points are feasible, keep the point closest to being feasible
with torch.no_grad():
posterior = model.posterior(X=X, posterior_transform=posterior_transform)
if sampler is None:
sampler = get_sampler(
posterior=posterior, sample_shape=torch.Size([num_samples])
)
samples = sampler(posterior)
# use the probability of feasibility as the objective for computing best points
obj_vals = compute_smoothed_feasibility_indicator(
constraints=constraints,
samples=samples,
eta=1e-3,
log=True,
)

elif infeas.any():
# set infeasible points to worse than worst objective across all samples
# Use clone() here to avoid deprecated `index_put_` on an expanded tensor
obj_vals = obj_vals.clone()
Expand Down
20 changes: 12 additions & 8 deletions test/acquisition/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
DeprecationError,
UnsupportedError,
)
from botorch.exceptions.warnings import BotorchWarning
from botorch.models import SingleTaskGP
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultivariateNormal
Expand Down Expand Up @@ -154,14 +155,17 @@ def test_compute_best_feasible_objective(self):
def objective(Y, X):
return Y.squeeze(-1) - 5.0

best_f = compute_best_feasible_objective(
samples=samples,
obj=obj,
constraints=[lambda X: torch.ones_like(X[..., 0])],
model=mm,
X_baseline=X,
objective=objective,
)
with self.assertWarnsRegex(
BotorchWarning, "ProbabilityOfFeasibility"
):
best_f = compute_best_feasible_objective(
samples=samples,
obj=obj,
constraints=[lambda X: torch.ones_like(X[..., 0])],
model=mm,
X_baseline=X,
objective=objective,
)
expected_best_f = torch.full(
sample_shape + batch_shape,
-get_infeasible_cost(X=X, model=mm, objective=objective).item(),
Expand Down