
Commit c431c6d

TobyBoyne authored and facebook-github-bot committed
Change how qNEHVI handles pending points (#2985)
Summary:

## Motivation

Currently, qNEHVI proposes repeated experiments in a batch when initial pending points are passed. This PR changes how this class handles pending points: `X_pending` is now always populated, and only appended in the forward pass if those points have not yet been cached. See issue #2983 for further discussion.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #2985

Test Plan: I will rewrite the tests in `test/acquisition/multi_objective/test_monte_carlo.py` to ensure that they pass.

Reviewed By: hvarfner

Differential Revision: D80533943

Pulled By: Balandat

fbshipit-source-id: c5f72413f5636bba69ac2088bc35f1bd819f0fc4
1 parent a5d74d9 · commit c431c6d

4 files changed: +70 -65 lines
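To make the change concrete, here is a minimal usage sketch (not part of the commit; the toy model, data, and sizes are hypothetical). After `set_X_pending`, `X_pending` now stays populated even with `cache_pending=True`, and `forward` only appends pending points that have not yet been cached:

import torch
from botorch.acquisition.multi_objective import (
    qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.models import SingleTaskGP

# Hypothetical two-objective toy problem.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

acqf = qLogNoisyExpectedHypervolumeImprovement(
    model=model,
    ref_point=[0.0, 0.0],
    X_baseline=train_X,
    cache_pending=True,  # pending points get folded into the cached baseline
)
acqf.set_X_pending(torch.rand(2, 2, dtype=torch.double))

# Before this commit, X_pending was reset to None once the points were
# cached, which could lead to repeated experiments being proposed in a
# batch. After this commit it stays populated:
assert acqf.X_pending is not None

# forward only concatenates pending points that are not yet cached, so the
# cached ones are not double-counted when evaluating a candidate batch.
vals = acqf(torch.rand(4, 3, 2, dtype=torch.double))  # t-batch of 4, q = 3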

botorch/acquisition/multi_objective/logei.py

Lines changed: 3 additions & 25 deletions
@@ -39,8 +39,6 @@
 from botorch.utils.transforms import (
     average_over_ensemble_models,
     concatenate_pending_points,
-    is_ensemble,
-    match_batch_shape,
     t_batch_mode_transform,
 )
 from torch import Tensor
@@ -439,32 +437,12 @@ def __init__(
         self.tau_max = tau_max
         self.fat = fat
 
-    @concatenate_pending_points
     @t_batch_mode_transform()
     @average_over_ensemble_models
     def forward(self, X: Tensor) -> Tensor:
-        X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
-        # NOTE: To ensure that we correctly sample `f(X)` from the joint distribution
-        # `f((X_baseline, X)) ~ P(f | D)`, it is critical to compute the joint posterior
-        # over X *and* X_baseline -- which also contains pending points whenever there
-        # are any -- since the baseline and pending values `f(X_baseline)` are
-        # generally pre-computed and cached before the `forward` call, see the docs of
-        # `cache_pending` for details.
-        # TODO: Improve the efficiency by not re-computing the X_baseline-X_baseline
-        # covariance matrix, but only the covariance of
-        # 1) X and X, and
-        # 2) X and X_baseline.
-        posterior = self.model.posterior(X_full)
-        # Account for possible one-to-many transform and the model batch dimensions in
-        # ensemble models.
-        event_shape_lag = 1 if is_ensemble(self.model) else 2
-        n_w = (
-            posterior._extended_shape()[X_full.dim() - event_shape_lag]
-            // X_full.shape[-2]
-        )
-        q_in = X.shape[-2] * n_w
-        self._set_sampler(q_in=q_in, posterior=posterior)
-        samples = self._get_f_X_samples(posterior=posterior, q_in=q_in)
+        # Get samples from the posterior, and manually concatenate pending points that
+        # have not yet been cached. Shared with qNEHVI.
+        samples, X = self._compute_posterior_samples_and_concat_pending(X)
         # Add previous nehvi from pending points.
         nehvi = self._compute_log_qehvi(samples=samples, X=X)
         if self.incremental_nehvi:
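For context (not part of the diff): the removed `concatenate_pending_points` decorator unconditionally appends all of `X_pending` to the candidate set before `forward` runs, roughly as paraphrased in the sketch below (see `botorch.utils.transforms` for the actual implementation). With `X_pending` now always kept populated, that blanket concatenation would double-count the pending points already folded into `X_baseline` by `cache_pending`, hence the move to a shared helper that appends only uncached points. The same reasoning applies to the identical change in `monte_carlo.py` below.

from functools import wraps

import torch
from botorch.utils.transforms import match_batch_shape


def concatenate_pending_points_sketch(method):
    """Rough paraphrase of the `concatenate_pending_points` decorator."""

    @wraps(method)
    def decorated(acqf, X, **kwargs):
        # Appends *all* pending points to X, with no notion of which of
        # them have already been cached into the acqf's baseline.
        if acqf.X_pending is not None:
            X = torch.cat([X, match_batch_shape(acqf.X_pending, X)], dim=-2)
        return method(acqf, X, **kwargs)

    return decorated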

botorch/acquisition/multi_objective/monte_carlo.py

Lines changed: 3 additions & 25 deletions
@@ -45,8 +45,6 @@
 from botorch.utils.transforms import (
     average_over_ensemble_models,
     concatenate_pending_points,
-    is_ensemble,
-    match_batch_shape,
     t_batch_mode_transform,
 )
 from torch import Tensor
@@ -349,31 +347,11 @@ def __init__(
         )
         self.fat = fat
 
-    @concatenate_pending_points
     @t_batch_mode_transform()
     @average_over_ensemble_models
     def forward(self, X: Tensor) -> Tensor:
-        X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
-        # NOTE: To ensure that we correctly sample `f(X)` from the joint distribution
-        # `f((X_baseline, X)) ~ P(f | D)`, it is critical to compute the joint posterior
-        # over X *and* X_baseline -- which also contains pending points whenever there
-        # are any -- since the baseline and pending values `f(X_baseline)` are
-        # generally pre-computed and cached before the `forward` call, see the docs of
-        # `cache_pending` for details.
-        # TODO: Improve the efficiency by not re-computing the X_baseline-X_baseline
-        # covariance matrix, but only the covariance of
-        # 1) X and X, and
-        # 2) X and X_baseline.
-        posterior = self.model.posterior(X_full)
-        # Account for possible one-to-many transform and the MCMC batch dimension in
-        # `SaasFullyBayesianSingleTaskGP`
-        event_shape_lag = 1 if is_ensemble(self.model) else 2
-        n_w = (
-            posterior._extended_shape()[X_full.dim() - event_shape_lag]
-            // X_full.shape[-2]
-        )
-        q_in = X.shape[-2] * n_w
-        self._set_sampler(q_in=q_in, posterior=posterior)
-        samples = self._get_f_X_samples(posterior=posterior, q_in=q_in)
+        # Get samples from the posterior, and manually concatenate pending points that
+        # have not yet been cached. Shared with qLogNEHVI.
+        samples, X = self._compute_posterior_samples_and_concat_pending(X)
         # Add previous nehvi from pending points.
         return self._compute_qehvi(samples=samples, X=X) + self._prev_nehvi

botorch/utils/multi_objective/hypervolume.py

Lines changed: 55 additions & 8 deletions
@@ -58,8 +58,10 @@
 )
 from botorch.utils.objective import compute_feasibility_indicator
 from botorch.utils.torch import BufferDict
+from botorch.utils.transforms import is_ensemble, match_batch_shape
 from torch import Tensor
 
+
 MIN_Y_RANGE = 1e-7
 
 
@@ -793,7 +795,7 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
                 BotorchWarning,
                 stacklevel=2,
             )
-            X_pending = X_pending.detach().clone()
+            self.X_pending = X_pending.detach().clone()
            if self.cache_pending:
                X_baseline = torch.cat([self._X_baseline, X_pending], dim=-2)
                # Number of new points is the total number of points minus
@@ -812,16 +814,9 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
                        .clamp_min(0.0)
                        .mean()
                    )
-                    # Set to None so that pending points are not concatenated in
-                    # forward.
-                    self.X_pending = None
                    # Set q_in=-1 to so that self.sampler is updated at the next
                    # forward call.
                    self.q_in = -1
-                else:
-                    self.X_pending = X_pending[-num_new_points:]
-        else:
-            self.X_pending = X_pending
 
     @property
     def _hypervolumes(self) -> Tensor:
@@ -836,6 +831,58 @@ def _hypervolumes(self) -> Tensor:
             .view(self._batch_sample_shape)
         )
 
+    def _compute_posterior_samples_and_concat_pending(
+        self, X: Tensor
+    ) -> tuple[Tensor, Tensor]:
+        r"""Get samples from the posterior, and concatenate uncached pending points.
+
+        Args:
+            X: `batch_shape x q x d` X Tensor pased into the `forward` method of an acqf
+
+        Returns:
+            A tuple containing samples of the latent function from the posterior, and
+            the `batch_shape x (q + num_uncached_pending) x d` X tensor including any
+            pending observations that have not been cached.
+        """
+        # Manually concatenate pending points only if:
+        # - pending points are not cached, or
+        # - number of pending points is less than max_iep
+        if self.X_pending is not None:
+            num_pending = self.X_pending.shape[-2]
+            num_X_baseline = self._X_baseline.shape[-2]
+            num_X_baseline_and_cached_pending = self.X_baseline.shape[-2]
+            num_uncached_pending = (
+                (num_pending + num_X_baseline - num_X_baseline_and_cached_pending)
+                if self.cache_pending
+                else num_pending
+            )
+            X_pending_uncached = self.X_pending[
+                ..., num_pending - num_uncached_pending :, :
+            ]
+            X = torch.cat([X, match_batch_shape(X_pending_uncached, X)], dim=-2)
+        X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
+        # NOTE: To ensure that we correctly sample `f(X)` from the joint distribution
+        # `f((X_baseline, X)) ~ P(f | D)`, it is critical to compute the joint posterior
+        # over X *and* X_baseline -- which also contains pending points whenever there
+        # are any -- since the baseline and pending values `f(X_baseline)` are
+        # generally pre-computed and cached before the `forward` call, see the docs of
+        # `cache_pending` for details.
+        # TODO: Improve the efficiency by not re-computing the X_baseline-X_baseline
+        # covariance matrix, but only the covariance of
+        # 1) X and X, and
+        # 2) X and X_baseline.
+        posterior = self.model.posterior(X_full)
+        # Account for possible one-to-many transform and the MCMC batch dimension in
+        # `SaasFullyBayesianSingleTaskGP`
+        event_shape_lag = 1 if is_ensemble(self.model) else 2
+        n_w = (
+            posterior._extended_shape()[X_full.dim() - event_shape_lag]
+            // X_full.shape[-2]
+        )
+        q_in = X.shape[-2] * n_w
+        self._set_sampler(q_in=q_in, posterior=posterior)
+        return self._get_f_X_samples(posterior=posterior, q_in=q_in), X
+
 
 def get_hypervolume_maximizing_subset(
     n: int, Y: Tensor, ref_point: Tensor
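A worked example of the uncached-pending arithmetic in the new helper (the sizes below are made up): suppose 10 baseline points and 5 pending points, 3 of which have already been absorbed into `self.X_baseline` by `cache_pending`.

# Hypothetical sizes, mirroring the helper's bookkeeping above:
num_pending = 5                         # self.X_pending.shape[-2]
num_X_baseline = 10                     # self._X_baseline.shape[-2], no pending
num_X_baseline_and_cached_pending = 13  # self.X_baseline.shape[-2], 3 cached

num_uncached_pending = (
    num_pending + num_X_baseline - num_X_baseline_and_cached_pending
)
assert num_uncached_pending == 2
# Only the last 2 pending points (self.X_pending[..., 3:, :]) still need to
# be concatenated onto the candidate set in forward; the first 3 already
# live in the cached baseline and would otherwise be counted twice.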

test/acquisition/multi_objective/test_monte_carlo.py

Lines changed: 9 additions & 7 deletions
@@ -1054,7 +1054,7 @@ def _test_qnehvi_with_CBD(
             acqf.set_X_pending(X_pending)
             if not incremental_nehvi:
                 self.assertAllClose(expected_val, acqf._prev_nehvi)
-            self.assertIsNone(acqf.X_pending)
+            self.assertTrue(torch.all(acqf.X_pending == X_pending))
             # check that X_baseline has been updated
             self.assertTrue(torch.equal(acqf.X_baseline[:-1], acqf._X_baseline))
             self.assertTrue(torch.equal(acqf.X_baseline[-1:], X_pending))
@@ -1112,7 +1112,7 @@ def _test_qnehvi_with_CBD(
             )
             mm._posterior._samples = mm._posterior._samples.squeeze(0)
             acqf.set_X_pending(X_pending2)
-            self.assertIsNone(acqf.X_pending)
+            self.assertTrue(torch.all(acqf.X_pending == X_pending2))
             # check that X_baseline has been updated
             self.assertTrue(torch.equal(acqf.X_baseline[:-2], acqf._X_baseline))
             self.assertTrue(torch.equal(acqf.X_baseline[-2:], X_pending2))
@@ -1129,7 +1129,9 @@ def _test_qnehvi_with_CBD(
                 acqf.set_X_pending(
                     torch.cat([X_pending2, X_pending2], dim=0).requires_grad_(True)
                 )
-                self.assertIsNone(acqf.X_pending)
+                self.assertTrue(
+                    torch.all(acqf.X_pending == torch.cat([X_pending2, X_pending2], dim=0))
+                )
                 self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
 
             # test max iep
@@ -1161,10 +1163,10 @@ def _test_qnehvi_with_CBD(
                     new_Y2,
                 ]
             )
-            # check that after second pending point is added, X_pending is set to None
-            # and the pending points are included in the box decompositions
+            # check that after second pending point is added, X_pending still includes
+            # pending points, and the pending points are included in the box decompositions
             acqf.set_X_pending(X_pending2)
-            self.assertIsNone(acqf.X_pending)
+            self.assertTrue(torch.all(acqf.X_pending == X_pending2))
             acqf_pareto_Y = acqf.partitioning.pareto_Y[0]
             self.assertTrue(torch.equal(acqf_pareto_Y[:-2], expected_pareto_Y))
             self.assertTrue(torch.equal(acqf_pareto_Y[-2:], expected_new_Y2))
@@ -1294,7 +1296,7 @@ def test_constrained_q_noisy_expected_hypervolume_improvement(self) -> None:
     def test_constrained_q_log_noisy_expected_hypervolume_improvement(self) -> None:
         for dtype, fat in product(
             (torch.float, torch.double),
-            (True, False),
+            (False, True),
         ):
             with self.subTest(dtype=dtype, fat=fat):
                 self._test_constrained_q_noisy_expected_hypervolume_improvement(
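Continuing the toy setup from the first sketch (the `model` and `train_X` there are hypothetical), the `max_iep` scenario exercised by the "test max iep" hunk looks roughly like this:

from botorch.acquisition.multi_objective import (
    qNoisyExpectedHypervolumeImprovement,
)

acqf = qNoisyExpectedHypervolumeImprovement(
    model=model,        # two-output SingleTaskGP from the earlier sketch
    ref_point=[0.0, 0.0],
    X_baseline=train_X,
    cache_pending=True,
    max_iep=1,          # cache once more than one point is pending
)
acqf.set_X_pending(torch.rand(2, 2, dtype=torch.double))

# The two pending points are folded into the cached box decompositions
# (acqf.partitioning), and, per this commit, X_pending remains populated
# rather than being reset to None.
assert acqf.X_pending is not None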
