Skip to content

Commit 2771266

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add a helper for producing a DeterministicModel using a Matheron path (#2435)
Summary: Pull Request resolved: #2435 Out of the box, pathwise sampling code does not produce outputs that conform to BoTorch modeling conventions. This adds a helper that produces a Matheron path and wraps it in a `GenericDeterministicModel` using a helper that reshapes the path outputs to mimic the expected behavior from the model's posterior. This will offer a straightforward replacement of `get_gp_samples`, which is being deprecated. Reviewed By: esantorella Differential Revision: D59941730 fbshipit-source-id: 6d5fe5cc2a48257bb564719f6dd63567c2d9b466
1 parent d14977c commit 2771266

File tree

3 files changed

+129
-3
lines changed

3 files changed

+129
-3
lines changed

botorch/sampling/pathwise/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from botorch.sampling.pathwise.posterior_samplers import (
2020
draw_matheron_paths,
21+
get_matheron_path_model,
2122
MatheronPath,
2223
)
2324
from botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths
@@ -28,6 +29,7 @@
2829
"draw_matheron_paths",
2930
"draw_kernel_feature_paths",
3031
"gen_kernel_features",
32+
"get_matheron_path_model",
3133
"gaussian_update",
3234
"GeneralizedLinearPath",
3335
"KernelEvaluationMap",

botorch/sampling/pathwise/posterior_samplers.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919

2020
from typing import Optional, Union
2121

22+
import torch
23+
from botorch.exceptions.errors import UnsupportedError
2224
from botorch.models.approximate_gp import ApproximateGPyTorchModel
25+
from botorch.models.deterministic import GenericDeterministicModel
26+
from botorch.models.model import ModelList
2327
from botorch.models.model_list_gp_regression import ModelListGP
2428
from botorch.sampling.pathwise.paths import PathDict, PathList, SamplePath
2529
from botorch.sampling.pathwise.prior_samplers import (
@@ -36,8 +40,9 @@
3640
)
3741
from botorch.utils.context_managers import delattr_ctx
3842
from botorch.utils.dispatcher import Dispatcher
43+
from botorch.utils.transforms import is_ensemble
3944
from gpytorch.models import ApproximateGP, ExactGP, GP
40-
from torch import Size
45+
from torch import Size, Tensor
4146

4247
DrawMatheronPaths = Dispatcher("draw_matheron_paths")
4348

@@ -83,13 +88,69 @@ def __init__(
8388
)
8489

8590

91+
def get_matheron_path_model(
92+
model: GP, sample_shape: Optional[Size] = None
93+
) -> GenericDeterministicModel:
94+
r"""Generates a deterministic model using a single Matheron path drawn
95+
from the model's posterior.
96+
97+
The deterministic model evalutes the output of `draw_matheron_paths`,
98+
and reshapes it to mimic the output behavior of the model's posterior.
99+
100+
Args:
101+
model: The model whose posterior is to be sampled.
102+
sample_shape: The shape of the sample paths to be drawn, if an ensemble
103+
of sample paths is desired. If this is specified, the resulting
104+
deterministic model will behave as if the `sample_shape` is prepended
105+
to the `batch_shape` of the model. The inputs used to evaluate the model
106+
must be adjusted to match.
107+
108+
Returns:
109+
A deterministic model that evaluates the Matheron path.
110+
"""
111+
sample_shape = Size() if sample_shape is None else sample_shape
112+
path = draw_matheron_paths(model, sample_shape=sample_shape)
113+
num_outputs = model.num_outputs
114+
if isinstance(model, ModelList) and len(model.models) != num_outputs:
115+
raise UnsupportedError("A model-list of multi-output models is not supported.")
116+
117+
def f(X: Tensor) -> Tensor:
118+
r"""Reshapes the path evaluations to bring the output dimension to the end.
119+
120+
Args:
121+
X: The input tensor of shape `batch_shape x q x d`.
122+
If the model is batched, `batch_shape` must be broadcastable to
123+
the model batch shape.
124+
125+
Returns:
126+
The output tensor of shape `batch_shape x q x m`.
127+
"""
128+
if num_outputs == 1:
129+
# For single-output, we lack the output dimension. Add one.
130+
res = path(X).unsqueeze(-1)
131+
elif isinstance(model, ModelList):
132+
# For model list, path evaluates to a list of tensors. Stack them.
133+
res = torch.stack(path(X), dim=-1)
134+
else:
135+
# For multi-output, path expects inputs broadcastable to
136+
# `model._aug_batch_shape x q x d` and returns outputs of shape
137+
# `model._aug_batch_shape x q`. Augmented batch shape includes the
138+
# `m` dimension, so we will unsqueeze that and transpose after.
139+
res = path(X.unsqueeze(-3)).transpose(-1, -2)
140+
return res
141+
142+
path_model = GenericDeterministicModel(f=f, num_outputs=num_outputs)
143+
path_model._is_ensemble = is_ensemble(model) or len(sample_shape) > 0
144+
return path_model
145+
146+
86147
def draw_matheron_paths(
87148
model: GP,
88149
sample_shape: Size,
89150
prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths,
90151
update_strategy: TPathwiseUpdate = gaussian_update,
91152
) -> MatheronPath:
92-
r"""Generates function draws from (an approximate) Gaussian process prior.
153+
r"""Generates function draws from (an approximate) Gaussian process posterior.
93154
94155
When evaluted, sample paths produced by this method return Tensors with dimensions
95156
`sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate

test/sampling/pathwise/test_posterior_samplers.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
from __future__ import annotations
88

99
from copy import deepcopy
10+
from typing import Any, Dict
1011

1112
import torch
13+
from botorch.exceptions.errors import UnsupportedError
1214
from botorch.models import ModelListGP, SingleTaskGP, SingleTaskVariationalGP
15+
from botorch.models.deterministic import GenericDeterministicModel
1316
from botorch.models.transforms.input import Normalize
1417
from botorch.models.transforms.outcome import Standardize
1518
from botorch.sampling.pathwise import draw_matheron_paths, MatheronPath, PathList
19+
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
1620
from botorch.sampling.pathwise.utils import get_train_inputs
1721
from botorch.utils.test_helpers import get_sample_moments, standardize_moments
1822
from botorch.utils.testing import BotorchTestCase
@@ -24,7 +28,7 @@
2428
class TestPosteriorSamplers(BotorchTestCase):
2529
def setUp(self, suppress_input_warnings: bool = True) -> None:
2630
super().setUp(suppress_input_warnings=suppress_input_warnings)
27-
tkwargs = {"device": self.device, "dtype": torch.float64}
31+
tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.float64}
2832
torch.manual_seed(0)
2933

3034
base = MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([]))
@@ -67,6 +71,8 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
6771
outcome_transform=outcome_transform,
6872
).to(**tkwargs)
6973

74+
self.tkwargs = tkwargs
75+
7076
def test_draw_matheron_paths(self):
7177
for seed, model in enumerate(
7278
(self.inferred_noise_gp, self.observed_noise_gp, self.variational_gp)
@@ -122,3 +128,60 @@ def _test_draw_matheron_paths(self, model, paths, sample_shape, atol=3):
122128
tol = atol * (num_features**-0.5 + sample_shape.numel() ** -0.5)
123129
for exact, estimate in zip(exact_moments, sample_moments):
124130
self.assertTrue(exact.allclose(estimate, atol=tol, rtol=0))
131+
132+
def test_get_matheron_path_model(self) -> None:
133+
model_list = ModelListGP(self.inferred_noise_gp, self.observed_noise_gp)
134+
moo_model = SingleTaskGP(
135+
train_X=torch.rand(5, 2, **self.tkwargs),
136+
train_Y=torch.rand(5, 2, **self.tkwargs),
137+
)
138+
139+
test_X = torch.rand(5, 2, **self.tkwargs)
140+
batch_test_X = torch.rand(3, 5, 2, **self.tkwargs)
141+
sample_shape = Size([2])
142+
sample_shape_X = torch.rand(3, 2, 5, 2, **self.tkwargs)
143+
for model in (self.inferred_noise_gp, moo_model, model_list):
144+
path_model = get_matheron_path_model(model=model)
145+
self.assertFalse(path_model._is_ensemble)
146+
self.assertIsInstance(path_model, GenericDeterministicModel)
147+
for X in (test_X, batch_test_X):
148+
self.assertEqual(
149+
model.posterior(X).mean.shape, path_model.posterior(X).mean.shape
150+
)
151+
path_model = get_matheron_path_model(model=model, sample_shape=sample_shape)
152+
self.assertTrue(path_model._is_ensemble)
153+
self.assertEqual(
154+
path_model.posterior(sample_shape_X).mean.shape,
155+
sample_shape_X.shape[:-1] + Size([model.num_outputs]),
156+
)
157+
158+
with self.assertRaisesRegex(
159+
UnsupportedError, "A model-list of multi-output models is not supported."
160+
):
161+
get_matheron_path_model(
162+
model=ModelListGP(self.inferred_noise_gp, moo_model)
163+
)
164+
165+
def test_get_matheron_path_model_batched(self) -> None:
166+
model = SingleTaskGP(
167+
train_X=torch.rand(4, 5, 2, **self.tkwargs),
168+
train_Y=torch.rand(4, 5, 2, **self.tkwargs),
169+
)
170+
model._is_ensemble = True
171+
path_model = get_matheron_path_model(model=model)
172+
self.assertTrue(path_model._is_ensemble)
173+
test_X = torch.rand(5, 2, **self.tkwargs)
174+
# This mimics the behavior of the acquisition functions unsqueezing the
175+
# model batch dimension for ensemble models.
176+
batch_test_X = torch.rand(3, 1, 5, 2, **self.tkwargs)
177+
# Explicitly matching X for completeness.
178+
complete_test_X = torch.rand(3, 4, 5, 2, **self.tkwargs)
179+
for X in (test_X, batch_test_X, complete_test_X):
180+
self.assertEqual(
181+
model.posterior(X).mean.shape, path_model.posterior(X).mean.shape
182+
)
183+
184+
# Test with sample_shape.
185+
path_model = get_matheron_path_model(model=model, sample_shape=Size([2, 6]))
186+
test_X = torch.rand(3, 2, 6, 4, 5, 2, **self.tkwargs)
187+
self.assertEqual(path_model.posterior(test_X).mean.shape, test_X.shape)

0 commit comments

Comments
 (0)