
Commit 52959e6

danielrjiang authored and facebook-github-bot committed
approximate qPI using MVNXPB (#1684)
Summary: Pull Request resolved: #1684. This is work by jiayuewan during his internship; I am simply moving it to OSS.

Reviewed By: Balandat

Differential Revision: D43337388

fbshipit-source-id: f9a24ce52fa2446b33b628e83a61a968c9062b25
1 parent 8db9c7b commit 52959e6
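
For context, a minimal usage sketch of the new acquisition function (not part of the diff; the toy data, model setup, and tensor shapes are illustrative assumptions):

import torch
from botorch.acquisition import qAnalyticProbabilityOfImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Hypothetical toy single-outcome training data.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True).sin()

model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Approximate probability that at least one of the q=3 candidates in each
# t-batch improves on the incumbent best observed value.
acqf = qAnalyticProbabilityOfImprovement(model=model, best_f=train_Y.max())
X = torch.rand(5, 3, 2, dtype=torch.double)  # 5 t-batches, q=3, d=2
print(acqf(X).shape)  # torch.Size([5])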

3 files changed: +225 −0 lines changed

botorch/acquisition/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
     NoisyExpectedImprovement,
     PosteriorMean,
     ProbabilityOfImprovement,
+    qAnalyticProbabilityOfImprovement,
     UpperConfidenceBound,
 )
 from botorch.acquisition.cost_aware import (
@@ -77,6 +78,7 @@
     "ProbabilityOfImprovement",
     "ProximalAcquisitionFunction",
     "UpperConfidenceBound",
+    "qAnalyticProbabilityOfImprovement",
     "qExpectedImprovement",
     "qKnowledgeGradient",
     "MaxValueBase",

botorch/acquisition/analytic.py

Lines changed: 67 additions & 0 deletions

@@ -17,6 +17,7 @@
 
 from contextlib import nullcontext
 from copy import deepcopy
+
 from typing import Dict, Optional, Tuple, Union
 
 import torch
@@ -27,6 +28,7 @@
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import Model
 from botorch.utils.constants import get_constants_like
+from botorch.utils.probability import MVNXPB
 from botorch.utils.probability.utils import (
     log_ndtr as log_Phi,
     log_phi,
@@ -37,6 +39,7 @@
 from botorch.utils.safe_math import log1mexp, logmeanexp
 from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
 from torch import Tensor
+from torch.nn.functional import pad
 
 _sqrt_2pi = math.sqrt(2 * math.pi)
 # the following two numbers are needed for _log_ei_helper
@@ -231,6 +234,70 @@ def forward(self, X: Tensor) -> Tensor:
         return Phi(u)
 
 
+class qAnalyticProbabilityOfImprovement(AnalyticAcquisitionFunction):
+    r"""Approximate, single-outcome batch Probability of Improvement using MVNXPB.
+
+    This implementation uses MVNXPB, a bivariate conditioning algorithm for
+    approximating P(a <= Y <= b) for multivariate normal Y.
+    See [Trinh2015bivariate]_. This (analytic) approximate q-PI is given by
+    `approx-qPI(X) = P(max Y >= best_f) = 1 - P(Y < best_f), Y ~ f(X),
+    X = (x_1,...,x_q)`, where `P(Y < best_f)` is estimated using MVNXPB.
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        best_f: Union[float, Tensor],
+        posterior_transform: Optional[PosteriorTransform] = None,
+        maximize: bool = True,
+        **kwargs,
+    ) -> None:
+        """qPI using an analytic approximation.
+
+        Args:
+            model: A fitted single-outcome model.
+            best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
+                the best function value observed so far (assumed noiseless).
+            posterior_transform: A PosteriorTransform. If using a multi-output
+                model, a PosteriorTransform that transforms the multi-output
+                posterior into a single-output posterior is required.
+            maximize: If True, consider the problem a maximization problem.
+        """
+        super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
+        self.maximize = maximize
+        if not torch.is_tensor(best_f):
+            best_f = torch.tensor(best_f)
+        self.register_buffer("best_f", best_f)
+
+    @t_batch_mode_transform()
+    def forward(self, X: Tensor) -> Tensor:
+        """Evaluate approximate qPI on the candidate set X.
+
+        Args:
+            X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim
+                design points each.
+
+        Returns:
+            A `batch_shape`-dim Tensor of approximate Probability of Improvement
+            values at the given design points `X`, where `batch_shape` is the
+            broadcasted batch shape of the model and the input `X`.
+        """
+        self.best_f = self.best_f.to(X)
+        posterior = self.model.posterior(
+            X=X, posterior_transform=self.posterior_transform
+        )
+
+        covariance = posterior.distribution.covariance_matrix
+        bounds = pad(
+            (self.best_f.unsqueeze(-1) - posterior.distribution.mean).unsqueeze(-1),
+            pad=(1, 0) if self.maximize else (0, 1),
+            value=-float("inf") if self.maximize else float("inf"),
+        )
+        # 1 - P(no improvement over best_f)
+        solver = MVNXPB(covariance_matrix=covariance, bounds=bounds)
+        return -solver.solve().expm1()
+
+
 class ExpectedImprovement(AnalyticAcquisitionFunction):
     r"""Single-outcome Expected Improvement (analytic).
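To unpack the `bounds` construction above: for maximization, each of the `q` points gets the integration box `(-inf, best_f - mu_i)`, so `solver.solve()` returns `log P(Y < best_f)` and `-solver.solve().expm1()` equals `1 - P(Y < best_f) = P(max Y >= best_f)`. A minimal sanity-check sketch for q = 1, where the approximation should match the exact normal tail probability (the numbers and variable names here are illustrative, not part of the diff):

import torch
from botorch.utils.probability import MVNXPB
from torch.distributions import Normal
from torch.nn.functional import pad

# q = 1, zero-mean unit-variance posterior, best_f = 1.96, maximization.
mean = torch.zeros(1, dtype=torch.double)
cov = torch.eye(1, dtype=torch.double)
best_f = torch.tensor(1.96, dtype=torch.double)

# bounds[..., 0] = lower = -inf, bounds[..., 1] = upper = best_f - mean
bounds = pad((best_f - mean).unsqueeze(-1), pad=(1, 0), value=-float("inf"))
qpi = -MVNXPB(covariance_matrix=cov, bounds=bounds).solve().expm1()

exact = 1.0 - Normal(0.0, 1.0).cdf(best_f)  # upper tail of N(0, 1)
print(qpi.item(), exact.item())  # both ~0.0250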

test/acquisition/test_analytic.py

Lines changed: 156 additions & 0 deletions

@@ -7,6 +7,7 @@
 import math
 
 import torch
+from botorch.acquisition import qAnalyticProbabilityOfImprovement
 from botorch.acquisition.analytic import (
     _compute_log_prob_feas,
     _ei_helper,
@@ -362,6 +363,161 @@ def test_probability_of_improvement_batch(self):
             LogProbabilityOfImprovement(model=mm2, best_f=0.0)
 
 
+class TestqAnalyticProbabilityOfImprovement(BotorchTestCase):
+    def test_q_analytic_probability_of_improvement(self):
+        for dtype in (torch.float, torch.double):
+            mean = torch.zeros(1, device=self.device, dtype=dtype)
+            cov = torch.eye(n=1, device=self.device, dtype=dtype)
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+
+            # basic test
+            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
+            X = torch.rand(1, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(0.0250, device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # basic test, maximize=False
+            module = qAnalyticProbabilityOfImprovement(
+                model=mm, best_f=1.96, maximize=False
+            )
+            X = torch.rand(1, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(0.9750, device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # basic test, posterior transform (single-output)
+            mean = torch.ones(1, device=self.device, dtype=dtype)
+            cov = torch.eye(n=1, device=self.device, dtype=dtype)
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            weights = torch.tensor([0.5], device=self.device, dtype=dtype)
+            transform = ScalarizedPosteriorTransform(weights)
+            module = qAnalyticProbabilityOfImprovement(
+                model=mm, best_f=0.0, posterior_transform=transform
+            )
+            X = torch.rand(1, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(0.8413, device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # basic test, posterior transform (multi-output)
+            mean = torch.ones(1, 2, device=self.device, dtype=dtype)
+            cov = torch.eye(n=2, device=self.device, dtype=dtype).unsqueeze(0)
+            mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            weights = torch.ones(2, device=self.device, dtype=dtype)
+            transform = ScalarizedPosteriorTransform(weights)
+            module = qAnalyticProbabilityOfImprovement(
+                model=mm, best_f=0.0, posterior_transform=transform
+            )
+            X = torch.rand(1, 1, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(0.9214, device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # basic test, q = 2
+            mean = torch.zeros(2, device=self.device, dtype=dtype)
+            cov = torch.eye(n=2, device=self.device, dtype=dtype)
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=1.96)
+            X = torch.zeros(2, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(0.049375, device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+    def test_batch_q_analytic_probability_of_improvement(self):
+        for dtype in (torch.float, torch.double):
+            # test batch mode
+            mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)
+            cov = (
+                torch.eye(n=1, device=self.device, dtype=dtype)
+                .unsqueeze(0)
+                .repeat(2, 1, 1)
+            )
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=0)
+            X = torch.rand(2, 1, 1, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # test batched model and best_f values
+            mean = torch.zeros(2, 1, device=self.device, dtype=dtype)
+            cov = (
+                torch.eye(n=1, device=self.device, dtype=dtype)
+                .unsqueeze(0)
+                .repeat(2, 1, 1)
+            )
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            best_f = torch.tensor([0.0, -1.0], device=self.device, dtype=dtype)
+            module = qAnalyticProbabilityOfImprovement(model=mm, best_f=best_f)
+            X = torch.rand(2, 1, 1, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor([[0.5, 0.8413]], device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # test batched model, output transform (single-output)
+            mean = torch.tensor([[0.0], [1.0]], device=self.device, dtype=dtype)
+            cov = (
+                torch.eye(n=1, device=self.device, dtype=dtype)
+                .unsqueeze(0)
+                .repeat(2, 1, 1)
+            )
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            weights = torch.tensor([0.5], device=self.device, dtype=dtype)
+            transform = ScalarizedPosteriorTransform(weights)
+            module = qAnalyticProbabilityOfImprovement(
+                model=mm, best_f=0.0, posterior_transform=transform
+            )
+            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor([0.5, 0.8413], device=self.device, dtype=dtype)
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # test batched model, output transform (multi-output)
+            mean = torch.tensor(
+                [[[1.0, 1.0]], [[0.0, 1.0]]], device=self.device, dtype=dtype
+            )
+            cov = (
+                torch.eye(n=2, device=self.device, dtype=dtype)
+                .unsqueeze(0)
+                .repeat(2, 1, 1)
+            )
+            mvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=cov)
+            posterior = GPyTorchPosterior(mvn)
+            mm = MockModel(posterior)
+            weights = torch.ones(2, device=self.device, dtype=dtype)
+            transform = ScalarizedPosteriorTransform(weights)
+            module = qAnalyticProbabilityOfImprovement(
+                model=mm, best_f=0.0, posterior_transform=transform
+            )
+            X = torch.rand(2, 1, 2, device=self.device, dtype=dtype)
+            pi = module(X)
+            pi_expected = torch.tensor(
+                [0.9214, 0.7602], device=self.device, dtype=dtype
+            )
+            self.assertTrue(torch.allclose(pi, pi_expected, atol=1e-4))
+
+            # test bad posterior transform class
+            with self.assertRaises(UnsupportedError):
+                qAnalyticProbabilityOfImprovement(
+                    model=mm, best_f=0.0, posterior_transform=IdentityMCObjective()
+                )
+
+
 class TestUpperConfidenceBound(BotorchTestCase):
     def test_upper_confidence_bound(self):
         for dtype in (torch.float, torch.double):
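
As a plausibility check on the q = 2 test value above: with two independent standard-normal points, there is no improvement over best_f = 1.96 only if both points fall below it, so qPI = 1 - Phi(1.96)^2. A back-of-the-envelope check (not part of the diff):

import math

phi = 0.5 * (1.0 + math.erf(1.96 / math.sqrt(2.0)))  # Phi(1.96) ~ 0.9750
print(1.0 - phi**2)  # ~0.049375, matching pi_expected in the q = 2 test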
