
Commit 9a93afb

saitcakmak authored and facebook-github-bot committed
Implements ExpectationPosteriorTransform (#903)
Summary:
Pull Request resolved: #903

Implements `ExpectationPosteriorTransform`, which transforms the `batch x (q * n_w) x m` posterior to a `batch x q x m` posterior of the expectation over the `n_w` points. Unlike the `RiskMeasureMCObjective`, this avoids posterior sampling over `q * n_w` points, which leads to significant speed-ups for large `q * n_w`.

Reviewed By: Balandat

Differential Revision: D29277116

fbshipit-source-id: a6be1c32d0343e6b1c99d2e76facc1c8d5b22d42
1 parent ce4900c commit 9a93afb
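To see how the new transform is meant to be used, here is a minimal usage sketch (not part of this commit; the model, data, and perturbation values are illustrative assumptions, and it assumes `InputPerturbation` expands each candidate into its `n_w` perturbed copies at posterior time, as its docstring describes):

import torch
from botorch.acquisition.objective import ExpectationPosteriorTransform
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import InputPerturbation

# Hypothetical toy data: 10 training points in 2D, one outcome.
train_X = torch.rand(10, 2)
train_Y = train_X.sum(dim=-1, keepdim=True)

# n_w = 4 perturbations; the input transform evaluates each candidate
# at X + delta for every delta in the perturbation set.
perturbation_set = 0.05 * torch.randn(4, 2)
model = SingleTaskGP(
    train_X, train_Y, input_transform=InputPerturbation(perturbation_set)
)

# The joint posterior over q = 3 candidates covers q * n_w = 12 points;
# the transform reduces it to a joint posterior over the 3 expectations.
test_X = torch.rand(3, 2)
tf = ExpectationPosteriorTransform(n_w=4)
expectation_posterior = tf(model.posterior(test_X))
print(expectation_posterior.mean.shape)  # torch.Size([3, 1])

Because the reduction happens on the posterior itself, any downstream MC sampling draws `q`-dimensional samples rather than `q * n_w`-dimensional ones, which is the source of the speed-up claimed above.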

File tree

4 files changed: +263 −2 lines changed

botorch/acquisition/multi_objective/multi_output_risk_measures.py
botorch/acquisition/objective.py
botorch/acquisition/risk_measures.py
test/acquisition/test_objective.py

botorch/acquisition/multi_objective/multi_output_risk_measures.py

Lines changed: 7 additions & 1 deletion

@@ -95,7 +95,13 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
 
 
 class MultiOutputExpectation(MultiOutputRiskMeasureMCObjective):
-    r"""A multi-output MC expectation risk measure."""
+    r"""A multi-output MC expectation risk measure.
+
+    For unconstrained problems, we recommend using the `ExpectationPosteriorTransform`
+    instead. `ExpectationPosteriorTransform` directly transforms the posterior
+    distribution over `q * n_w` to a posterior of `q` expectations, significantly
+    reducing the cost of posterior sampling as a result.
+    """
 
     def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         r"""Calculate the expectation of the given samples. Expectation is

botorch/acquisition/objective.py

Lines changed: 107 additions & 0 deletions

@@ -16,11 +16,14 @@
 from typing import Callable, List, Optional
 
 import torch
+from botorch.exceptions.errors import UnsupportedError
 from botorch.models.model import Model
 from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
 from botorch.posteriors.posterior import Posterior
 from botorch.sampling import IIDNormalSampler, MCSampler
 from botorch.utils import apply_constraints
+from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
+from gpytorch.lazy import lazify
 from torch import Tensor
 from torch.nn import Module
 

@@ -137,6 +140,110 @@ def __init__(self, weights: Tensor, offset: float = 0.0) -> None:
         super().__init__(weights=weights, offset=offset)
 
 
+class ExpectationPosteriorTransform(PosteriorTransform):
+    r"""Transform the `batch x (q * n_w) x m` posterior into a `batch x q x m`
+    posterior of the expectation. The expectation is calculated over each
+    consecutive `n_w` block of points in the posterior.
+
+    This is intended for use with `InputPerturbation` or `AppendFeatures` for
+    optimizing the expectation over `n_w` points. This should not be used when
+    there are constraints present, since this does not take into account
+    the feasibility of the objectives.
+
+    Note: This is different from `ScalarizedPosteriorTransform` in that
+    this operates over the q-batch dimension.
+    """
+
+    def __init__(self, n_w: int, weights: Optional[Tensor] = None) -> None:
+        r"""A posterior transform calculating the expectation over the q-batch
+        dimension.
+
+        Args:
+            n_w: The number of points in the q-batch of the posterior to compute
+                the expectation over. This corresponds to the size of the
+                `feature_set` of `AppendFeatures` or the size of the `perturbation_set`
+                of `InputPerturbation`.
+            weights: An optional `n_w x m`-dim tensor of weights. Can be used to
+                compute a weighted expectation. Weights are normalized before use.
+        """
+        super().__init__()
+        if weights is not None:
+            if weights.dim() != 2 or weights.shape[0] != n_w:
+                raise ValueError("`weights` must be a tensor of size `n_w x m`.")
+            if torch.any(weights < 0):
+                raise ValueError("`weights` must be non-negative.")
+        else:
+            weights = torch.ones(n_w, 1)
+        # Normalize the weights.
+        weights = weights / weights.sum(dim=0)
+        self.register_buffer("weights", weights)
+        self.n_w = n_w
+
+    def evaluate(self, Y: Tensor) -> Tensor:
+        r"""Evaluate the expectation of a set of outcomes.
+
+        Args:
+            Y: A `batch_shape x (q * n_w) x m`-dim tensor of outcomes.
+
+        Returns:
+            A `batch_shape x q x m`-dim tensor of expectation outcomes.
+        """
+        batch_shape, m = Y.shape[:-2], Y.shape[-1]
+        weighted_Y = Y.view(*batch_shape, -1, self.n_w, m) * self.weights.to(Y)
+        return weighted_Y.sum(dim=-2)
+
+    def forward(self, posterior: GPyTorchPosterior) -> GPyTorchPosterior:
+        r"""Compute the posterior of the expectation.
+
+        Args:
+            posterior: An `m`-outcome joint posterior over `q * n_w` points.
+
+        Returns:
+            An `m`-outcome joint posterior over `q` expectations.
+        """
+        org_mvn = posterior.mvn
+        if getattr(org_mvn, "_interleaved", False):
+            raise UnsupportedError(
+                "`ExpectationPosteriorTransform` does not support "
+                "interleaved posteriors."
+            )
+        # Initialize the weight matrix of shape compatible with the mvn.
+        org_event_shape = org_mvn.event_shape
+        batch_shape = org_mvn.batch_shape
+        q = org_event_shape[0] // self.n_w
+        m = 1 if len(org_event_shape) == 1 else org_event_shape[-1]
+        tkwargs = {"device": org_mvn.loc.device, "dtype": org_mvn.loc.dtype}
+        weights = torch.zeros(q * m, q * self.n_w * m, **tkwargs)
+        # Make sure self.weights has the correct dtype/device and shape.
+        self.weights = self.weights.to(org_mvn.loc).expand(self.n_w, m)
+        # Fill in the non-zero entries of the weight matrix.
+        # We want each row to have non-zero weights for the corresponding
+        # `n_w` sized diagonal. The `m` outcomes are not interleaved.
+        for i in range(q * m):
+            weights[i, self.n_w * i : self.n_w * (i + 1)] = self.weights[:, i // q]
+        # Transform the mean.
+        new_loc = (
+            (weights @ org_mvn.loc.unsqueeze(-1))
+            .view(*batch_shape, m, q)
+            .transpose(-1, -2)
+        )
+        # Transform the covariance matrix.
+        org_cov = (
+            org_mvn.lazy_covariance_matrix
+            if org_mvn.islazy
+            else org_mvn.covariance_matrix
+        )
+        new_cov = weights @ (org_cov @ weights.t())
+        if m == 1:
+            new_mvn = MultivariateNormal(new_loc.squeeze(-1), lazify(new_cov))
+        else:
+            # Using MTMVN since we pass a single loc and covar for all `m` outputs.
+            new_mvn = MultitaskMultivariateNormal(
+                new_loc, lazify(new_cov), interleaved=False
+            )
+        return GPyTorchPosterior(mvn=new_mvn)
+
+
 class MCAcquisitionObjective(Module, ABC):
     r"""Abstract base class for MC-based objectives.
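As a worked illustration of the weight matrix constructed in `forward` above (a standalone sketch, not library code): for `q = 2`, `n_w = 3`, `m = 1` with uniform weights, each row of the `q*m x (q * n_w * m)` matrix averages one consecutive `n_w` block, so `W @ loc` yields the block means and `W @ K @ W^T` the block-averaged covariance:

import torch

# Rebuild the weight matrix from forward() for q = 2, n_w = 3, m = 1.
q, n_w, m = 2, 3, 1
w = torch.full((n_w, m), 1.0 / n_w)  # uniform, already normalized
W = torch.zeros(q * m, q * n_w * m)
for i in range(q * m):
    W[i, n_w * i : n_w * (i + 1)] = w[:, i // q]

# Each row averages one consecutive n_w block of the mean vector.
loc = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
print(W @ loc)  # tensor([2., 5.])

# The covariance transforms as W K W^T; with K = I this gives a
# 2 x 2 matrix with 1/3 on the diagonal and 0 off the diagonal.
K = torch.eye(6)
print(W @ K @ W.t())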

botorch/acquisition/risk_measures.py

Lines changed: 7 additions & 1 deletion

@@ -228,7 +228,13 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
 
 
 class Expectation(RiskMeasureMCObjective):
-    r"""The expectation risk measure."""
+    r"""The expectation risk measure.
+
+    For unconstrained problems, we recommend using the `ExpectationPosteriorTransform`
+    instead. `ExpectationPosteriorTransform` directly transforms the posterior
+    distribution over `q * n_w` to a posterior of `q` expectations, significantly
+    reducing the cost of posterior sampling as a result.
+    """
 
     def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         r"""Calculate the expectation corresponding to the given samples.

test/acquisition/test_objective.py

Lines changed: 142 additions & 0 deletions

@@ -12,18 +12,23 @@
 from botorch.acquisition import LearnedObjective
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
+    ExpectationPosteriorTransform,
     GenericMCObjective,
     IdentityMCObjective,
     LinearMCObjective,
     MCAcquisitionObjective,
     PosteriorTransform,
     ScalarizedPosteriorTransform,
 )
+from botorch.exceptions.errors import UnsupportedError
 from botorch.models.deterministic import PosteriorMeanModel
 from botorch.models.pairwise_gp import PairwiseGP
+from botorch.posteriors import GPyTorchPosterior
 from botorch.sampling.samplers import SobolQMCNormalSampler
 from botorch.utils import apply_constraints
 from botorch.utils.testing import _get_test_posterior, BotorchTestCase
+from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
+from gpytorch.lazy import lazify
 from torch import Tensor
 
 

@@ -83,6 +88,143 @@ def test_scalarized_posterior_transform(self):
         self.assertTrue(torch.equal(val, val_expected))
 
 
+class TestExpectationPosteriorTransform(BotorchTestCase):
+    def test_init(self):
+        # Without weights.
+        tf = ExpectationPosteriorTransform(n_w=5)
+        self.assertEqual(tf.n_w, 5)
+        self.assertTrue(torch.allclose(tf.weights, torch.ones(5, 1) * 0.2))
+        # Errors with weights.
+        with self.assertRaisesRegex(ValueError, "a tensor of size"):
+            ExpectationPosteriorTransform(n_w=3, weights=torch.ones(5, 1))
+        with self.assertRaisesRegex(ValueError, "non-negative"):
+            ExpectationPosteriorTransform(n_w=3, weights=-torch.ones(3, 1))
+        # Successful init with weights.
+        weights = torch.tensor([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
+        tf = ExpectationPosteriorTransform(n_w=3, weights=weights)
+        self.assertTrue(torch.allclose(tf.weights, weights / torch.tensor([6.0, 12.0])))
+
+    def test_evaluate(self):
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"dtype": dtype, "device": self.device}
+            # Without weights.
+            tf = ExpectationPosteriorTransform(n_w=3)
+            Y = torch.rand(3, 6, 2, **tkwargs)
+            self.assertTrue(
+                torch.allclose(tf.evaluate(Y), Y.view(3, 2, 3, 2).mean(dim=-2))
+            )
+            # With weights - weights intentionally doesn't use tkwargs.
+            weights = torch.tensor([[1.0, 2.0], [2.0, 1.0]])
+            tf = ExpectationPosteriorTransform(n_w=2, weights=weights)
+            expected = (Y.view(3, 3, 2, 2) * weights.to(Y)).sum(dim=-2) / 3.0
+            self.assertTrue(torch.allclose(tf.evaluate(Y), expected))
+
+    def test_expectation_posterior_transform(self):
+        tkwargs = {"dtype": torch.float, "device": self.device}
+        # Without weights, simple expectation, single output, no batch.
+        # q = 2, n_w = 3.
+        org_loc = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], **tkwargs)
+        org_covar = torch.tensor(
+            [
+                [1.0, 0.8, 0.7, 0.3, 0.2, 0.1],
+                [0.8, 1.0, 0.9, 0.25, 0.15, 0.1],
+                [0.7, 0.9, 1.0, 0.2, 0.2, 0.05],
+                [0.3, 0.25, 0.2, 1.0, 0.7, 0.6],
+                [0.2, 0.15, 0.2, 0.7, 1.0, 0.7],
+                [0.1, 0.1, 0.05, 0.6, 0.7, 1.0],
+            ],
+            **tkwargs
+        )
+        org_mvn = MultivariateNormal(org_loc, lazify(org_covar))
+        org_post = GPyTorchPosterior(mvn=org_mvn)
+        tf = ExpectationPosteriorTransform(n_w=3)
+        tf_post = tf(org_post)
+        self.assertIsInstance(tf_post, GPyTorchPosterior)
+        self.assertEqual(tf_post.sample().shape, torch.Size([1, 2, 1]))
+        tf_mvn = tf_post.mvn
+        self.assertIsInstance(tf_mvn, MultivariateNormal)
+        expected_loc = torch.tensor([2.0, 5.0], **tkwargs)
+        # This is the average of each 3 x 3 block.
+        expected_covar = torch.tensor([[0.8667, 0.1722], [0.1722, 0.7778]], **tkwargs)
+        self.assertTrue(torch.allclose(tf_mvn.loc, expected_loc))
+        self.assertTrue(
+            torch.allclose(tf_mvn.covariance_matrix, expected_covar, atol=1e-3)
+        )
+
+        # With weights, 2 outputs, batched.
+        tkwargs = {"dtype": torch.double, "device": self.device}
+        # q = 2, n_w = 2, m = 2, leading to 8 values for loc and 8x8 cov.
+        org_loc = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], **tkwargs)
+        # We have 2 4x4 matrices with 0s as filler. Each block is for one outcome.
+        # Each 2x2 sub block corresponds to `n_w`.
+        org_covar = torch.tensor(
+            [
+                [1.0, 0.8, 0.3, 0.2, 0.0, 0.0, 0.0, 0.0],
+                [0.8, 1.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0],
+                [0.3, 0.2, 1.2, 0.5, 0.0, 0.0, 0.0, 0.0],
+                [0.2, 0.1, 0.5, 1.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 1.0, 0.7, 0.4, 0.3],
+                [0.0, 0.0, 0.0, 0.0, 0.7, 0.8, 0.3, 0.2],
+                [0.0, 0.0, 0.0, 0.0, 0.4, 0.3, 1.4, 0.5],
+                [0.0, 0.0, 0.0, 0.0, 0.3, 0.2, 0.5, 1.2],
+            ],
+            **tkwargs
+        )
+        # Making it batched by adding two more batches, mostly the same.
+        org_loc = org_loc.repeat(3, 1)
+        org_loc[1] += 100
+        org_loc[2] += 1000
+        org_covar = org_covar.repeat(3, 1, 1)
+        # Construct the transform with weights.
+        weights = torch.tensor([[1.0, 3.0], [2.0, 1.0]])
+        tf = ExpectationPosteriorTransform(n_w=2, weights=weights)
+        # Construct the posterior.
+        org_mvn = MultitaskMultivariateNormal(
+            # The return of mvn.loc and the required input are different.
+            # We constructed it according to the output of mvn.loc,
+            # reshaping here to have the required `b x n x t` shape.
+            org_loc.view(3, 2, 4).transpose(-2, -1),
+            lazify(org_covar),
+            interleaved=True,  # To test the error.
+        )
+        org_post = GPyTorchPosterior(mvn=org_mvn)
+        # Error if interleaved.
+        with self.assertRaisesRegex(UnsupportedError, "interleaved"):
+            tf(org_post)
+        # Construct the non-interleaved posterior.
+        org_mvn = MultitaskMultivariateNormal(
+            org_loc.view(3, 2, 4).transpose(-2, -1),
+            lazify(org_covar),
+            interleaved=False,
+        )
+        org_post = GPyTorchPosterior(mvn=org_mvn)
+        self.assertTrue(torch.equal(org_mvn.loc, org_loc))
+        tf_post = tf(org_post)
+        self.assertIsInstance(tf_post, GPyTorchPosterior)
+        self.assertEqual(tf_post.sample().shape, torch.Size([1, 3, 2, 2]))
+        tf_mvn = tf_post.mvn
+        self.assertIsInstance(tf_mvn, MultitaskMultivariateNormal)
+        expected_loc = torch.tensor([[1.6667, 3.6667, 5.25, 7.25]], **tkwargs).repeat(
+            3, 1
+        )
+        expected_loc[1] += 100
+        expected_loc[2] += 1000
+        # This is the weighted average of each 2 x 2 block.
+        expected_covar = torch.tensor(
+            [
+                [1.0889, 0.1667, 0.0, 0.0],
+                [0.1667, 0.8, 0.0, 0.0],
+                [0.0, 0.0, 0.875, 0.35],
+                [0.0, 0.0, 0.35, 1.05],
+            ],
+            **tkwargs
+        ).repeat(3, 1, 1)
+        self.assertTrue(torch.allclose(tf_mvn.loc, expected_loc, atol=1e-3))
+        self.assertTrue(
+            torch.allclose(tf_mvn.covariance_matrix, expected_covar, atol=1e-3)
+        )
+
+
 class TestMCAcquisitionObjective(BotorchTestCase):
     def test_abstract_raises(self):
         with self.assertRaises(TypeError):
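Beyond these unit tests, an informal cross-check sketch (sample count arbitrary; an identity covariance is assumed purely for brevity): the analytic expectation posterior should agree with a Monte Carlo estimate obtained by block-averaging samples from the original posterior:

import torch
from botorch.acquisition.objective import ExpectationPosteriorTransform
from botorch.posteriors import GPyTorchPosterior
from gpytorch.distributions import MultivariateNormal
from gpytorch.lazy import lazify

# Original posterior over q * n_w = 2 * 3 points, single outcome.
loc = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
post = GPyTorchPosterior(mvn=MultivariateNormal(loc, lazify(torch.eye(6))))
tf = ExpectationPosteriorTransform(n_w=3)
tf_post = tf(post)

# MC estimate: sample the original posterior, average each n_w block.
samples = post.rsample(torch.Size([50000]))  # 50000 x 6 x 1
mc_mean = samples.view(-1, 2, 3, 1).mean(dim=-2).mean(dim=0)
print(tf_post.mean.squeeze(-1))  # tensor([2., 5.])
print(mc_mean.squeeze(-1))  # approximately [2., 5.]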
