
Commit 7ac20d8

hvarfner authored and facebook-github-bot committed
Performance & runtime improvements to info-theoretic acquisition functions (1/N)
Summary: A series of improvements directed towards improving the performance of PES & JES, as well as their multi-objective counterparts.

## Motivation

As pointed out by SebastianAment in [this paper](https://arxiv.org/pdf/2310.20708), the BoTorch variant of JES, and to an extent PES, is brutally slow and suspiciously ill-performing. To bring them up to their potential, I've added a series of performance improvements:

**1. Improvements to get_optimal_samples and optimize_posterior_samples**: As this is an integral part of their efficiency, I've added `suggested_points` (similar to `sample_around_best`) to `optimize_posterior_samples`. Marginal runtime improvement in acquisition optimization (sampling time practically unchanged):

![runtime_pr1](https://github.com/user-attachments/assets/a4e6d24b-2d77-4f6f-bb79-4dea8ac9304a)

Substantial performance improvement:

![pr1_regret](https://github.com/user-attachments/assets/d534f716-9a78-43db-acf4-3a795ca1e597)

**2. Added initializer to acquisition function optimization**: Similar to KG, ES methods have sensible suggestions for acquisition function optimization in the form of the sampled optima. This drastically reduces the time of acquisition function optimization, which could on occasion take 30+ seconds when `num_restarts` was large (>4). Benchmarking INC.

**2b. Multi-objective support for initializer**: By renaming arguments of the multi-objective variants, we get consistency and support for the MO variants.

**3. Enabled gradient-based optimization for PES**: The current implementation contains a while-loop which forces the gradients to be recursively computed. This commonly causes NaN gradients, which is why the recommended option is `"with_grad": False` in the tutorial. A single `detach()` alleviates this issue, enabling gradient-based optimization (see the sketch below). NOTE: this has NOT been ablated, since the non-gradient optimization is extremely computationally demanding.

X-link: #2748

Reviewed By: saitcakmak

Differential Revision: D69787454

Pulled By: hvarfner
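The `detach()` pattern referenced in item 3 can be illustrated with a minimal, self-contained PyTorch snippet. This is not the PES implementation; the loop and loss below are hypothetical stand-ins for the recursive computation described above.

```python
import torch

# Hypothetical iterative refinement loop, illustrating why a single detach()
# stops gradients from being propagated recursively through every previous
# iteration (the failure mode described for PES above).
x = torch.randn(4, dtype=torch.float64, requires_grad=True)
state = x.clone()
for _ in range(50):
    # Detaching the loop state keeps the autograd graph shallow; without it,
    # backward() traverses all 50 iterations and the gradients can degenerate.
    state = torch.tanh(state).detach() + 0.1 * x
loss = state.sum()
loss.backward()
print(x.grad)  # finite gradients, computed only through the final iteration
```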
1 parent ee328f2 commit 7ac20d8

File tree

4 files changed: +81 -26 lines changed


botorch/acquisition/utils.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -579,6 +579,7 @@ def get_optimal_samples(
     posterior_transform: ScalarizedPosteriorTransform | None = None,
     objective: MCAcquisitionObjective | None = None,
     return_transformed: bool = False,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     """Draws sample paths from the posterior and maximizes the samples using GD.
 
@@ -596,7 +597,8 @@ def get_optimal_samples(
         objective: An MCAcquisitionObjective, used to negate the objective or otherwise
             transform sample outputs. Cannot be combined with `posterior_transform`.
         return_transformed: If True, return the transformed samples.
-
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.
     Returns:
         The optimal input locations and corresponding outputs, x* and f*.
 
@@ -625,12 +627,20 @@ def get_optimal_samples(
         sample_shape=torch.Size([num_optima]),
         ensemble_as_batch=True,
     )
+    suggested_points = prune_inferior_points(
+        model=model,
+        X=model.train_inputs[0],
+        posterior_transform=posterior_transform,
+        objective=objective,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,
         raw_samples=raw_samples,
         num_restarts=num_restarts,
         sample_transform=sample_transform,
         return_transformed=return_transformed,
+        suggested_points=suggested_points,
+        options=options,
     )
     return optimal_inputs, optimal_outputs
```
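For orientation, a hedged usage sketch of the extended signature shown above. The toy data, model, and bounds are illustrative and not part of this commit; the `options` keys shown are the ones read downstream in `optimize_posterior_samples`.

```python
import torch
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP

# toy data and model, purely for illustration
train_X = torch.rand(20, 3, dtype=torch.float64)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
bounds = torch.tensor([[0.0] * 3, [1.0] * 3], dtype=torch.float64)

# the new `options` dict is forwarded to initial-candidate generation; the
# training inputs are pruned internally and passed along as suggested_points
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=8,
    options={"frac_random": 0.9, "eta": 2.0},
)
print(optimal_inputs.shape, optimal_outputs.shape)  # (8, 3) and (8, 1)
```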

botorch/utils/sampling.py

Lines changed: 35 additions & 10 deletions
```diff
@@ -1005,10 +1005,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1017,19 +1019,27 @@ def optimize_posterior_samples(
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much
+            higher than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples for time-efficiency.
         sample_transform: A callable transform of the sample outputs (e.g.
             MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.
 
     Returns:
         A two-element tuple containing:
         - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
         - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
             `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
+    options = {} if options is None else options
 
     def path_func(x) -> Tensor:
         res = paths(x)
@@ -1038,21 +1048,35 @@ def path_func(x) -> Tensor:
 
         return res.squeeze(-1)
 
-    candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
-        bounds=bounds,
-    )
     # queries all samples on all candidates - output shape
     # raw_samples x num_optima x num_models
+    frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
+    candidate_set = draw_sobol_samples(
+        bounds=bounds, n=round(raw_samples * frac_random), q=1
+    ).squeeze(-2)
+    if frac_random < 1:
+        perturbed_suggestions = sample_truncated_normal_perturbations(
+            X=suggested_points,
+            n_discrete_points=round(raw_samples * (1 - frac_random)),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        candidate_set = torch.cat((candidate_set, perturbed_suggestions))
+
     candidate_queries = path_func(candidate_set)
-    argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
-    X_top_k = candidate_set[argtop_k, :]
+    idx = boltzmann_sample(
+        function_values=candidate_queries.unsqueeze(-1),
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+        replacement=False,
+    )
+    ics = candidate_set[idx, :]
 
     # to avoid circular import, the import occurs here
     from botorch.generation.gen import gen_candidates_scipy
 
     X_top_k, f_top_k = gen_candidates_scipy(
-        X_top_k,
+        ics,
         path_func,
         lower_bounds=bounds[0],
         upper_bounds=bounds[1],
@@ -1108,8 +1132,9 @@ def boltzmann_sample(
         eta *= temp_decrease
     weights = torch.exp(eta * norm_weights)
 
+    # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
     return batched_multinomial(
-        weights=weights, num_samples=num_samples, replacement=replacement
+        weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
    )
```
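The candidate-generation change above can be summarized in a small, self-contained sketch. This is an illustrative approximation, not the library code: `select_initial_conditions` is a hypothetical helper, and the plain Gaussian perturbation and softmax weighting are simplifications of `sample_truncated_normal_perturbations` and `boltzmann_sample`.

```python
import torch


def select_initial_conditions(
    path_func, bounds, suggested_points=None,
    raw_samples=2048, num_restarts=4, frac_random=0.9, sigma=1e-2, eta=2.0,
):
    """Two-stage candidate generation + soft selection of starting points."""
    dim = bounds.shape[1]
    frac = 1.0 if suggested_points is None else frac_random
    n_random = round(raw_samples * frac)
    # stage 1: quasi-random (Sobol) candidates covering the whole box
    sobol = torch.quasirandom.SobolEngine(dimension=dim, scramble=True)
    candidates = bounds[0] + (bounds[1] - bounds[0]) * sobol.draw(n_random).to(bounds)
    if suggested_points is not None:
        # stage 2: densify around high-valued suggestions with small perturbations,
        # clamped back into the bounds (simplified truncated-normal sampling)
        n_sugg = raw_samples - n_random
        idx = torch.randint(len(suggested_points), (n_sugg,))
        noise = sigma * torch.randn(n_sugg, dim, dtype=bounds.dtype)
        perturbed = (suggested_points[idx] + noise).clamp(bounds[0], bounds[1])
        candidates = torch.cat([candidates, perturbed])
    values = path_func(candidates)  # cheap batched evaluation of the sample path
    # Boltzmann-weighted selection of num_restarts starting points without
    # replacement, standing in for boltzmann_sample in botorch.utils.sampling
    z = (values - values.mean()) / values.std().clamp_min(1e-9)
    weights = torch.softmax(eta * z, dim=-1)
    return candidates[torch.multinomial(weights, num_restarts, replacement=False)]


# usage with a toy "sample path" on the unit square
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
ics = select_initial_conditions(
    path_func=lambda X: -(X - 0.5).pow(2).sum(-1),
    bounds=bounds,
    suggested_points=torch.tensor([[0.4, 0.6], [0.5, 0.5]]),
)
print(ics.shape)  # num_restarts x d starting points for gradient-based refinement
```

The selected starting points then go to gradient-based refinement (`gen_candidates_scipy` in the diff above); Boltzmann sampling trades a little exploitation for diversity among the restarts compared to a plain top-k selection.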

test/acquisition/test_utils.py

Lines changed: 22 additions & 7 deletions
```diff
@@ -34,6 +34,7 @@
 )
 from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import SingleTaskGP
+from botorch.utils.test_helpers import get_fully_bayesian_model, get_model
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal
 
@@ -417,17 +418,14 @@ def test_project_to_sample_points(self):
 
 
 class TestGetOptimalSamples(BotorchTestCase):
-    def test_get_optimal_samples(self):
-        dims = 3
-        dtype = torch.float64
+    def _test_get_optimal_samples_base(self, model):
+        dims = model.train_inputs[0].shape[1]
+        dtype = model.train_targets.dtype
+        batch_shape = model.batch_shape
         for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2}
         num_optima = 7
-        batch_shape = (3,)
 
         bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-        X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
-        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
-        model = SingleTaskGP(train_X=X, train_Y=Y)
         posterior_transform = ScalarizedPosteriorTransform(
             weights=torch.ones(1, dtype=dtype)
         )
@@ -442,6 +440,7 @@ def test_get_optimal_samples(self):
             num_optima=num_optima,
             **for_testing_speed_kwargs,
         )
+
         correct_X_shape = (num_optima,) + batch_shape + (dims,)
         correct_f_shape = (num_optima,) + batch_shape + (1,)
         self.assertEqual(X_opt_def.shape, correct_X_shape)
@@ -523,6 +522,22 @@ def test_get_optimal_samples(self):
             **for_testing_speed_kwargs,
         )
 
+    def test_optimal_samples(self):
+        dims = 3
+        dtype = torch.float64
+        X = torch.rand(4, dims, dtype=dtype)
+        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
+        model = get_model(train_X=X, train_Y=Y)
+        self._test_get_optimal_samples_base(model)
+        fully_bayesian_model = get_fully_bayesian_model(
+            train_X=X,
+            train_Y=Y,
+            num_models=4,
+            standardize_model=True,
+            infer_noise=True,
+        )
+        self._test_get_optimal_samples_base(fully_bayesian_model)
+
 
 class TestPreferenceUtils(BotorchTestCase):
     def test_repeat_to_match_aug_dim(self):
```

test/utils/test_sampling.py

Lines changed: 13 additions & 8 deletions
```diff
@@ -578,9 +578,13 @@ def test_optimize_posterior_samples(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
-        nums_optima = (1, 7)
-        batch_shapes = ((), (2,), (3, 2))
+        for_testing_speed_kwargs = {
+            "raw_samples": 64,
+            "num_restarts": 2,
+            "options": {"eta": 10},
+        }
+        nums_optima = (1, 5)
+        batch_shapes = ((), (3,))
         posterior_transforms = (
             None,
             ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),
@@ -589,16 +593,19 @@ def test_optimize_posterior_samples(self):
             nums_optima, batch_shapes, posterior_transforms
         ):
             bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-            X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
+            X = torch.rand(*batch_shape, 3, dims, dtype=dtype)
             Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)
 
             # having a noiseless model all but guarantees that the found optima
             # will be better than the observations
-            model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
+            model = SingleTaskGP(
+                train_X=X, train_Y=Y, train_Yvar=torch.full_like(Y, eps)
+            )
             model.covar_module.lengthscale = 0.5
             paths = get_matheron_path_model(
                 model=model, sample_shape=torch.Size([num_optima])
             )
+
             X_opt, f_opt = optimize_posterior_samples(
                 paths=paths,
                 bounds=bounds,
@@ -616,8 +623,6 @@ def test_optimize_posterior_samples(self):
             self.assertTrue(torch.all(X_opt >= bounds[0]))
             self.assertTrue(torch.all(X_opt <= bounds[1]))
 
-            # Check that the all found optima are larger than the observations
-            # This is not 100% deterministic, but just about.
             Y_queries = paths(X)
             # this is when we negate, so the values should be smaller
             if posterior_transform:
@@ -642,7 +647,7 @@ def test_optimize_posterior_samples_multi_objective(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
+        for_testing_speed_kwargs = {"raw_samples": 64, "num_restarts": 2}
         num_optima = 5
         batch_shape = (3,)
 
```