
Commit 2fe7d00

Balandat authored and facebook-github-bot committed
Introduce base_sample_shape property to Posterior objects (#718)
Summary: This allows base sample shapes that differ from the posterior's `event_shape`. This is relevant, e.g., for GPyTorch MVN posteriors when the covariance is singular, or when a low-rank approximation of the root decomposition is used. See the corresponding GPyTorch PR: cornellius-gp/gpytorch#1502. This change also does some cleanup and moves `HigherOrderGPPosterior` into the `posteriors` module.

Pull Request resolved: #718
Reviewed By: sdaulton
Differential Revision: D26608568
Pulled By: Balandat
fbshipit-source-id: 8e78d2b662b1ed45cfb1adc2e1517321a31a224f
1 parent f8ff7d2 commit 2fe7d00
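
To make the new property concrete, here is a minimal sketch of the idea (all shapes and the rank-2 root `L` below are illustrative assumptions, not taken from this commit): a sampler can draw base samples of shape `sample_shape x base_sample_shape` rather than `sample_shape x event_shape`, and `rsample` maps them through a low-rank covariance root.

import torch

# Illustrative shapes (assumed): a posterior over n = 5 test points and
# m = 1 outputs whose covariance root has rank r = 2.
event_shape = torch.Size([5, 1])        # what posterior.event_shape reports
base_sample_shape = torch.Size([2, 1])  # what posterior.base_sample_shape reports

# A sampler now draws base samples matching base_sample_shape ...
sample_shape = torch.Size([16])
base_samples = torch.randn(sample_shape + base_sample_shape)  # 16 x 2 x 1

# ... and rsample maps them through the rank-2 root L (n x r), so only
# r-dimensional base samples are needed even though the event has n points.
L = torch.randn(5, 2)     # stand-in for the posterior covariance root
mean = torch.zeros(5, 1)  # stand-in for the posterior mean
samples = mean + L @ base_samples
print(samples.shape)  # torch.Size([16, 5, 1]) == sample_shape + event_shape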

File tree

15 files changed (+474, -401 lines)

botorch/acquisition/utils.py

Lines changed: 4 additions & 4 deletions

@@ -269,12 +269,12 @@ def prune_inferior_points(
     with torch.no_grad():
         posterior = model.posterior(X=X)
     if sampler is None:
-        if posterior.event_shape.numel() > SobolEngine.MAXDIM:
+        if posterior.base_sample_shape.numel() > SobolEngine.MAXDIM:
             if settings.debug.on():
                 warnings.warn(
-                    f"Sample dimension q*m={posterior.event_shape.numel()} exceeding "
-                    f"Sobol max dimension ({SobolEngine.MAXDIM}). Using iid samples "
-                    "instead.",
+                    f"Sample dimension q*m={posterior.base_sample_shape.numel()} "
+                    f"exceeding Sobol max dimension ({SobolEngine.MAXDIM}). "
+                    "Using iid samples instead.",
                     SamplingWarning,
                 )
             sampler = IIDNormalSampler(num_samples=num_samples)
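
The practical effect of this check, illustrated with made-up numbers (not from the diff): the Sobol dimension test now scales with the number of base samples actually needed, so QMC sampling remains available when a low-rank root keeps that number under `SobolEngine.MAXDIM` even though `q * m` exceeds it.

from torch.quasirandom import SobolEngine

# Hypothetical dimensions: event_shape.numel() = q * m would rule out Sobol,
# but a low-rank base sample dimension stays within the engine's limit.
event_dims = 25_000  # q * m, exceeds SobolEngine.MAXDIM (21201)
base_dims = 4_096    # numel() of a low-rank base_sample_shape
print(event_dims > SobolEngine.MAXDIM)  # True  -> old check fell back to iid
print(base_dims > SobolEngine.MAXDIM)   # False -> new check keeps Sobol/QMC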

botorch/models/higher_order_gp.py

Lines changed: 23 additions & 221 deletions

@@ -28,7 +28,11 @@
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform, Standardize
 from botorch.models.utils import gpt_posterior_settings
-from botorch.posteriors import GPyTorchPosterior, TransformedPosterior
+from botorch.posteriors import (
+    GPyTorchPosterior,
+    HigherOrderGPPosterior,
+    TransformedPosterior,
+)
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import Kernel, MaternKernel
@@ -46,7 +50,7 @@
 from gpytorch.models import ExactGP
 from gpytorch.priors.torch_priors import GammaPrior, MultivariateNormalPrior
 from gpytorch.settings import fast_pred_var, skip_posterior_variances
-from torch import Size, Tensor
+from torch import Tensor
 from torch.nn import ModuleList, Parameter, ParameterList


@@ -61,10 +65,13 @@ class FlattenedStandardize(Standardize):
     """

     def __init__(
-        self, output_shape: Size, batch_shape: Size = None, min_stdv: float = 1e-8
+        self,
+        output_shape: torch.Size,
+        batch_shape: torch.Size = None,
+        min_stdv: float = 1e-8,
     ):
         if batch_shape is None:
-            batch_shape = Size()
+            batch_shape = torch.Size()

         super(FlattenedStandardize, self).__init__(
             m=1, outputs=None, batch_shape=batch_shape, min_stdv=min_stdv
@@ -131,207 +138,6 @@ def untransform_posterior(
         )


-class HigherOrderGPPosterior(GPyTorchPosterior):
-    r"""
-    Posterior class for a Higher order Gaussian process model [Zhe2019hogp]_.
-    Extends the standard GPyTorch posterior class by overwriting the rsample
-    method. The posterior variance is handled internally by the HigherOrderGP
-    model. HOGP is a tensorized GP model so the posterior covariance grows to be
-    extremely large, but is highly structured, which means that we can exploit
-    Kronecker identities to sample from the posterior using Matheron's rule as
-    described in [Doucet2010sampl]_. In general, this posterior should ONLY be
-    used for HOGP models that have highly structured covariances. It should also
-    only be used internally when called from the HigherOrderGP.posterior(...)
-    method.
-    """
-
-    def __init__(
-        self,
-        mvn: MultivariateNormal,
-        joint_covariance_matrix: LazyTensor,
-        train_train_covar: LazyTensor,
-        test_train_covar: LazyTensor,
-        train_targets: Tensor,
-        output_shape: Size,
-        num_outputs: int,
-    ) -> None:
-        r"""A Posterior for HigherOrderGP models.
-
-        Args:
-            mvn: Posterior multivariate normal distribution.
-            joint_covariance_matrix: Joint test-train covariance matrix over the
-                entire tensor.
-            train_train_covar: Covariance matrix of train points in the data space.
-            test_train_covar: Covariance matrix of test x train points in the
-                data space.
-            train_targets: Vectorized training responses.
-            output_shape: Shape of the output training responses.
-            num_outputs: Batch shaping of the model.
-        """
-        super().__init__(mvn)
-        self.joint_covariance_matrix = joint_covariance_matrix
-        self.train_train_covar = train_train_covar
-        self.test_train_covar = test_train_covar
-        self.train_targets = train_targets
-        self.output_shape = output_shape
-        self._is_mt = True
-        self.num_outputs = num_outputs
-
-    @property
-    def event_shape(self):
-        # overwrites the standard event_shape call to inform samplers that
-        # n + 2 n_train samples need to be drawn rather than n samples
-        # TODO: Expose a sample shape property that is independent of the event
-        # shape and handle those transparently in the samplers.
-        batch_shape = self.joint_covariance_matrix.shape[:-2]
-        sampling_shape = (
-            self.joint_covariance_matrix.shape[-2] + self.train_train_covar.shape[-2]
-        )
-        return batch_shape + torch.Size((sampling_shape,))
-
-    def _prepare_base_samples(
-        self, sample_shape: torch.Size, base_samples: Tensor = None
-    ) -> Tensor:
-        covariance_matrix = self.joint_covariance_matrix
-        joint_size = covariance_matrix.shape[-1]
-        batch_shape = covariance_matrix.batch_shape
-
-        if base_samples is not None:
-            if base_samples.shape[: len(sample_shape)] != sample_shape:
-                raise RuntimeError("sample_shape disagrees with shape of base_samples.")
-
-            appended_shape = joint_size + self.train_train_covar.shape[-1]
-            if appended_shape != base_samples.shape[-1]:
-                # get base_samples to the correct shape by expanding as sample shape,
-                # batch shape, then rest of dimensions. We have to add first the
-                # sample shape, then the batch shape of the model, and then finally
-                # the shape of the test data points squeezed into a single dimension,
-                # accessed from the test_train_covar.
-                base_sample_shapes = (
-                    sample_shape + batch_shape + self.test_train_covar.shape[-2:-1]
-                )
-                if base_samples.nelement() == base_sample_shapes.numel():
-                    base_samples = base_samples.reshape(base_sample_shapes)
-
-                    new_base_samples = torch.randn(
-                        *sample_shape,
-                        *batch_shape,
-                        appended_shape - base_samples.shape[-1],
-                        device=base_samples.device,
-                        dtype=base_samples.dtype,
-                    )
-                    base_samples = torch.cat((base_samples, new_base_samples), dim=-1)
-                else:
-                    # nuke the base samples if we cannot use them.
-                    base_samples = None
-
-        if base_samples is None:
-            # TODO: Allow qMC sampling
-            base_samples = torch.randn(
-                *sample_shape,
-                *batch_shape,
-                joint_size,
-                device=covariance_matrix.device,
-                dtype=covariance_matrix.dtype,
-            )
-
-            noise_base_samples = torch.randn(
-                *sample_shape,
-                *batch_shape,
-                self.train_train_covar.shape[-1],
-                device=covariance_matrix.device,
-                dtype=covariance_matrix.dtype,
-            )
-        else:
-            # finally split up the base samples
-            noise_base_samples = base_samples[..., joint_size:]
-            base_samples = base_samples[..., :joint_size]
-
-        perm_list = [*range(1, base_samples.ndim), 0]
-        return base_samples.permute(*perm_list), noise_base_samples.permute(*perm_list)
-
-    def rsample(
-        self,
-        sample_shape: Optional[torch.Size] = None,
-        base_samples: Optional[Tensor] = None,
-    ) -> Tensor:
-        r"""Sample from the posterior (with gradients).
-
-        As the posterior covariance is difficult to draw from in this model,
-        we implement Matheron's rule as described in [Doucet2010sampl]_. This may
-        not work entirely correctly for deterministic base samples unless base
-        samples are provided that are of shape `n + 2 * n_train`, because the
-        sampling method draws `2 * n_train` samples as well as the standard `n`
-        samples.
-
-        Args:
-            sample_shape: A `torch.Size` object specifying the sample shape. To
-                draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
-                of `n` samples each, set to `torch.Size([b, n])`.
-            base_samples: An (optional) Tensor of `N(0, I)` base samples of
-                appropriate dimension, typically obtained from a `Sampler`.
-                This is used for deterministic optimization.
-
-        Returns:
-            A `sample_shape x event_shape`-dim Tensor of samples from the posterior.
-        """
-        if sample_shape is None:
-            sample_shape = torch.Size([1])
-
-        base_samples, noise_base_samples = self._prepare_base_samples(
-            sample_shape, base_samples
-        )
-
-        # base samples now have trailing sample dimension
-        covariance_matrix = self.joint_covariance_matrix
-        covar_root = covariance_matrix.root_decomposition().root
-        samples = covar_root.matmul(base_samples)
-
-        # now pluck out Y_x and X_x
-        noiseless_train_marginal_samples = samples[
-            ..., : self.train_train_covar.shape[-1], :
-        ]
-        test_marginal_samples = samples[..., self.train_train_covar.shape[-1] :, :]
-        # we need to add noise to the train_joint_samples
-        # THIS ASSUMES CONSTANT NOISE
-        noise_std = self.train_train_covar.lazy_tensors[1]._diag[..., 0] ** 0.5
-        # TODO: cleanup the reshaping here
-        # expands the noise to allow broadcasting against the noise base samples
-        # reshape_as or view_as don't work here because we need to expand to
-        # broadcast against `samples x batch_shape x output_shape` while noise_std
-        # is `batch_shape x 1`.
-        if self.num_outputs > 1 or noise_std.ndim > 1:
-            ntms_dims = [
-                i == noise_std.shape[0] for i in noiseless_train_marginal_samples.shape
-            ]
-            for matched in ntms_dims:
-                if not matched:
-                    noise_std = noise_std.unsqueeze(-1)
-
-        # we need to add noise into the noiseless samples
-        noise_marginal_samples = noise_std * noise_base_samples
-
-        train_marginal_samples = (
-            noiseless_train_marginal_samples + noise_marginal_samples
-        )
-
-        # compute y - Y_x
-        train_rhs = self.train_targets - train_marginal_samples
-
-        # K_{train, train}^{-1} (y - Y_x)
-        # internally, this solve is done using Kronecker algebra and is fast.
-        kinv_rhs = self.train_train_covar.inv_matmul(train_rhs)
-        # multiply by cross-covariance
-        test_updated_samples = self.test_train_covar.matmul(kinv_rhs)
-
-        # add samples
-        test_cond_samples = test_marginal_samples + test_updated_samples
-        test_cond_samples = test_cond_samples.permute(
-            test_cond_samples.ndim - 1, *range(0, test_cond_samples.ndim - 1)
-        )
-
-        # reshape samples to be the actual size of the train targets
-        return test_cond_samples.reshape(*sample_shape, *self.output_shape)
-
-
 class HigherOrderGP(BatchedMultiOutputGPyTorchModel, ExactGP):
     r"""
     A Higher order Gaussian process model (HOGP) (predictions are matrices/tensors) as
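
For reference, the removed `rsample` above implements Matheron's rule ([Doucet2010sampl]_, per its docstring). A sketch of the update in LaTeX, with notation assumed rather than taken from the diff:

\[
\tilde{f}_* \,=\, f_* \,+\, K_{*,\mathrm{train}}
    \bigl(K_{\mathrm{train},\mathrm{train}} + \sigma^2 I\bigr)^{-1}
    \bigl(y - (f_{\mathrm{train}} + \varepsilon)\bigr),
\qquad \varepsilon \sim \mathcal{N}(0, \sigma^2 I).
\]

Here (f_train, f_*) is a joint prior sample drawn via the root decomposition of the joint covariance; y - (f_train + ε) corresponds to `train_rhs` in the code, the inverse solve is `train_train_covar.inv_matmul` (fast thanks to the Kronecker structure), and the cross-covariance product is `test_train_covar.matmul`.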
@@ -564,19 +370,20 @@ def condition_on_observations(
         self, X: Tensor, Y: Tensor, **kwargs: Any
     ) -> HigherOrderGP:
         r"""Condition the model on new observations.
+
         Args:
             X: A `batch_shape x n' x d`-dim Tensor, where `d` is the dimension of
-                the feature space, `m` is the number of points per batch, and
-                `batch_shape` is the batch shape (must be compatible with the
-                batch shape of the model).
-
+                the feature space, `m` is the number of points per batch, and
+                `batch_shape` is the batch shape (must be compatible with the
+                batch shape of the model).
             Y: A `batch_shape' x n' x m_d`-dim Tensor, where `m_d` is the shaping
-                of the model outputs, `n'` is the number of points per batch, and
-                `batch_shape'` is the batch shape of the observations.
-                `batch_shape'` must be broadcastable to `batch_shape` using
-                standard broadcasting semantics. If `Y` has fewer batch dimensions
-                than `X`, its is assumed that the missing batch dimensions are
-                the same for all `Y`.
+                of the model outputs, `n'` is the number of points per batch, and
+                `batch_shape'` is the batch shape of the observations.
+                `batch_shape'` must be broadcastable to `batch_shape` using
+                standard broadcasting semantics. If `Y` has fewer batch dimensions
+                than `X`, it is assumed that the missing batch dimensions are
+                the same for all `Y`.
+
         Returns:
             A `BatchedMultiOutputGPyTorchModel` object of the same type with
             `n + n'` training examples, representing the original model
@@ -693,12 +500,7 @@ def posterior(
             train_train_covar=train_train_covar,
             test_train_covar=test_train_covar,
             joint_covariance_matrix=full_covar.clone(),
-            output_shape=Size(
-                (
-                    *X.shape[:-1],
-                    *self.target_shape,
-                )
-            ),
+            output_shape=X.shape[:-1] + self.target_shape,
             num_outputs=self._num_outputs,
         )
         if hasattr(self, "outcome_transform"):

botorch/models/transforms/outcome.py

Lines changed: 8 additions & 0 deletions

@@ -4,6 +4,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+r"""
+Outcome transformations for automatically transforming and un-transforming
+model outputs. Outcome transformations are typically part of a Model and
+applied (i) within the model constructor to transform the train observations
+to the model space, and (ii) in the `Model.posterior` call to untransform
+the model posterior back to the original space.
+"""
+
 from __future__ import annotations

 from abc import ABC, abstractmethod
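
A short usage sketch of the two touchpoints the new module docstring describes (the model and toy data here are illustrative assumptions, not part of the diff):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

# (i) The transform standardizes train_Y inside the model constructor ...
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = 100.0 + 10.0 * torch.randn(10, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))

# ... and (ii) Model.posterior un-transforms the posterior, so predictions
# come back on the original scale of train_Y.
posterior = model.posterior(torch.rand(3, 2, dtype=torch.double))
print(posterior.mean)  # values near 100, not near 0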

botorch/posteriors/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,13 +6,15 @@

 from botorch.posteriors.deterministic import DeterministicPosterior
 from botorch.posteriors.gpytorch import GPyTorchPosterior
+from botorch.posteriors.higher_order import HigherOrderGPPosterior
 from botorch.posteriors.posterior import Posterior
 from botorch.posteriors.transformed import TransformedPosterior


 __all__ = [
     "DeterministicPosterior",
     "GPyTorchPosterior",
+    "HigherOrderGPPosterior",
     "Posterior",
     "TransformedPosterior",
 ]

botorch/posteriors/gpytorch.py

Lines changed: 8 additions & 0 deletions

@@ -35,6 +35,14 @@ def __init__(self, mvn: MultivariateNormal) -> None:
         self.mvn = mvn
         self._is_mt = isinstance(mvn, MultitaskMultivariateNormal)

+    @property
+    def base_sample_shape(self) -> torch.Size:
+        r"""The shape of a base sample used for constructing posterior samples."""
+        shape = self.mvn.batch_shape + self.mvn.base_sample_shape
+        if not self._is_mt:
+            shape += torch.Size([1])
+        return shape
+
     @property
     def device(self) -> torch.device:
         r"""The torch device of the posterior."""
