Skip to content

Commit b3d3074

Browse files
James Wilson and facebook-github-bot
authored and committed
sampling.pathwise (#1463)
Summary: Pull Request resolved: #1463 This PR contains code for efficiently sampling functions from (approximate) GP priors and posteriors. The functionality introduced here is largely exposed via three high-level methods: - `gen_kernel_features`: Generates a feature map that represents (or approximates) a kernel. - `draw_kernel_feature_paths`: Draws functions from a Bayesian-linear-model-based approximation to a GP prior. By default, uses random Fourier features (RFF) to represent stationary priors. - `draw_matheron_paths`: Generates draws from an approximate GP posterior using Matheron's rule. By default, this method combines draws from an RFF-based approximate prior with exact Gaussian updates. For details, see [[Wilson et al., 2020]](https://arxiv.org/abs/2002.09309#) and [[Wilson et al., 2021]](https://arxiv.org/pdf/2011.04026.pdf). Please let us know if you run into issues. As always, contributions are welcomed. Reviewed By: saitcakmak Differential Revision: D40662482 fbshipit-source-id: 95761c3cd830afba4f7a11db30e5203199b595ce
1 parent 150d673 commit b3d3074

31 files changed

+2504
-31
lines changed

botorch/optim/closures/model_closures.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
1313

1414
from botorch.optim.closures.core import ForwardBackwardClosure
15-
from botorch.optim.utils import TNone
1615
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
16+
from botorch.utils.types import NoneType
1717
from gpytorch.mlls import (
1818
ExactMarginalLogLikelihood,
1919
MarginalLogLikelihood,
@@ -151,9 +151,9 @@ def closure(**kwargs: Any) -> Tensor:
151151
return closure
152152

153153

154-
@GetLossClosure.register(MarginalLogLikelihood, object, object, TNone)
154+
@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
155155
def _get_loss_closure_fallback_internal(
156-
mll: MarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
156+
mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
157157
) -> Callable[[], Tensor]:
158158
r"""Fallback loss closure with internally managed data."""
159159

@@ -165,9 +165,9 @@ def closure(**kwargs: Any) -> Tensor:
165165
return closure
166166

167167

168-
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, TNone)
168+
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
169169
def _get_loss_closure_exact_internal(
170-
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
170+
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
171171
) -> Callable[[], Tensor]:
172172
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""
173173

@@ -181,9 +181,9 @@ def closure(**kwargs: Any) -> Tensor:
181181
return closure
182182

183183

184-
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, TNone)
184+
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
185185
def _get_loss_closure_sum_internal(
186-
mll: SumMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
186+
mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
187187
) -> Callable[[], Tensor]:
188188
r"""SumMarginalLogLikelihood loss closure with internally managed data."""
189189

botorch/optim/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@
4343
from botorch.optim.utils import (
4444
_filter_kwargs,
4545
_get_extra_mll_args,
46-
DEFAULT,
4746
get_name_filter,
4847
get_parameters_and_bounds,
4948
TorchAttr,
5049
)
5150
from botorch.optim.utils.model_utils import get_parameters
51+
from botorch.utils.types import DEFAULT
5252
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
5353
from gpytorch.settings import fast_computations
5454
from numpy import ndarray

botorch/optim/utils/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
_filter_kwargs,
1414
_handle_numerical_errors,
1515
_warning_handler_template,
16-
DEFAULT,
17-
TNone,
1816
)
1917
from botorch.optim.utils.model_utils import (
2018
_get_extra_mll_args,
@@ -40,7 +38,6 @@
4038
"_warning_handler_template",
4139
"as_ndarray",
4240
"columnwise_clamp",
43-
"DEFAULT",
4441
"fix_features",
4542
"get_name_filter",
4643
"get_bounds_as_ndarray",
@@ -53,5 +50,4 @@
5350
"sample_all_priors",
5451
"set_tensors_from_ndarray_1d",
5552
"TorchAttr",
56-
"TNone",
5753
]

botorch/optim/utils/common.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,6 @@
1616
import numpy as np
1717
from linear_operator.utils.errors import NanError, NotPSDError
1818

19-
TNone = type(None)
20-
21-
22-
class _TDefault:
23-
pass
24-
25-
26-
DEFAULT = _TDefault()
27-
2819

2920
def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
3021
r"""Filter out kwargs that are not applicable for a given function.

botorch/optim/utils/numpy_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import numpy as np
1616
import torch
17-
from botorch.optim.utils.common import TNone
17+
from botorch.utils.types import NoneType
1818
from numpy import ndarray
1919
from torch import Tensor
2020

@@ -137,7 +137,9 @@ def set_tensors_from_ndarray_1d(
137137

138138
def get_bounds_as_ndarray(
139139
parameters: Dict[str, Tensor],
140-
bounds: Dict[str, Tuple[Union[float, Tensor, TNone], Union[float, Tensor, TNone]]],
140+
bounds: Dict[
141+
str, Tuple[Union[float, Tensor, NoneType], Union[float, Tensor, NoneType]]
142+
],
141143
) -> Optional[np.ndarray]:
142144
r"""Helper method for converting bounds into an ndarray.
143145
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from botorch.sampling.pathwise.features import (
9+
gen_kernel_features,
10+
KernelEvaluationMap,
11+
KernelFeatureMap,
12+
)
13+
from botorch.sampling.pathwise.paths import (
14+
GeneralizedLinearPath,
15+
PathDict,
16+
PathList,
17+
SamplePath,
18+
)
19+
from botorch.sampling.pathwise.posterior_samplers import (
20+
draw_matheron_paths,
21+
MatheronPath,
22+
)
23+
from botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths
24+
from botorch.sampling.pathwise.update_strategies import gaussian_update
25+
26+
27+
__all__ = [
28+
"draw_matheron_paths",
29+
"draw_kernel_feature_paths",
30+
"gen_kernel_features",
31+
"gaussian_update",
32+
"GeneralizedLinearPath",
33+
"KernelEvaluationMap",
34+
"KernelFeatureMap",
35+
"MatheronPath",
36+
"SamplePath",
37+
"PathDict",
38+
"PathList",
39+
]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from botorch.sampling.pathwise.features.generators import gen_kernel_features
9+
from botorch.sampling.pathwise.features.maps import (
10+
FeatureMap,
11+
KernelEvaluationMap,
12+
KernelFeatureMap,
13+
)
14+
15+
__all__ = [
16+
"FeatureMap",
17+
"gen_kernel_features",
18+
"KernelEvaluationMap",
19+
"KernelFeatureMap",
20+
]
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
.. [rahimi2007random]
9+
A. Rahimi and B. Recht. Random features for large-scale kernel machines.
10+
Advances in Neural Information Processing Systems 20 (2007).
11+
12+
.. [sutherland2015error]
13+
D. J. Sutherland and J. Schneider. On the error of random Fourier features.
14+
arXiv preprint arXiv:1506.02785 (2015).
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any, Callable
20+
21+
import torch
22+
from botorch.exceptions.errors import UnsupportedError
23+
from botorch.sampling.pathwise.features.maps import KernelFeatureMap
24+
from botorch.sampling.pathwise.utils import (
25+
ChainedTransform,
26+
FeatureSelector,
27+
InverseLengthscaleTransform,
28+
OutputscaleTransform,
29+
SineCosineTransform,
30+
)
31+
from botorch.utils.dispatcher import Dispatcher
32+
from botorch.utils.sampling import draw_sobol_normal_samples
33+
from gpytorch import kernels
34+
from gpytorch.kernels.kernel import Kernel
35+
from torch import Size, Tensor
36+
from torch.distributions import Gamma
37+
38+
# Type alias for callables that build a `KernelFeatureMap` from a kernel and two
# integers (presumably `num_inputs` and `num_outputs`, mirroring the signature of
# `gen_kernel_features` -- TODO confirm against callers).
TKernelFeatureMapGenerator = Callable[[Kernel, int, int], KernelFeatureMap]
39+
# Dispatcher backing `gen_kernel_features`. Kernel-specific generators in this
# module attach themselves to it via `@GenKernelFeatures.register(<KernelType>)`.
GenKernelFeatures = Dispatcher("gen_kernel_features")
40+
41+
42+
def gen_kernel_features(
    kernel: kernels.Kernel,
    num_inputs: int,
    num_outputs: int,
    **kwargs: Any,
) -> KernelFeatureMap:
    r"""Build a finite-dimensional feature map for a kernel.

    The returned map :math:`\phi: \mathcal{X} \to \mathbb{R}^{n}` is meant to satisfy
    :math:`k(x, x') ≈ \phi(x)^{T} \phi(x')`. The actual construction is delegated to
    the `GenKernelFeatures` dispatcher, which is called with `kernel` as its only
    positional argument. For stationary kernels :math:`k`, the default is the method
    of random Fourier features; see [rahimi2007random]_ and [sutherland2015error]_.

    Args:
        kernel: The kernel :math:`k` to be represented via a finite-dim basis.
        num_inputs: The number of input features.
        num_outputs: The number of kernel features.
        **kwargs: Extra keyword arguments forwarded unchanged to the dispatched
            generator.

    Returns:
        The `KernelFeatureMap` produced by the dispatched generator.
    """
    # Keep `num_inputs`/`num_outputs` ahead of any extra options, exactly as they
    # would appear in a direct keyword call.
    dispatch_kwargs = {"num_inputs": num_inputs, "num_outputs": num_outputs, **kwargs}
    return GenKernelFeatures(kernel, **dispatch_kwargs)
64+
65+
66+
def _gen_fourier_features(
    kernel: kernels.Kernel,
    weight_generator: Callable[[Size], Tensor],
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    r"""Generate a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{2l}` that
    approximates a stationary kernel so that :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`.

    Following [sutherland2015error]_, we represent complex exponentials by pairs of
    basis functions :math:`\phi_{i}(x) = \sin(x^\top w_{i})` and
    :math:`\phi_{i + l}(x) = \cos(x^\top w_{i})`.

    Args:
        kernel: A stationary kernel :math:`k(x, x') = k(x - x')`.
        weight_generator: A callable used to generate weight vectors :math:`w`. It
            is called once with a 2-dimensional `Size` of the form
            `(kernel.batch_shape.numel() * num_outputs // 2, num_inputs)`.
        num_inputs: The number of input features. Replaced by
            `len(kernel.active_dims)` when `kernel.active_dims` is not None.
        num_outputs: The number of Fourier features. Must be a positive, even
            integer, since features come in sine/cosine pairs.

    Returns:
        A `KernelFeatureMap` whose weight has shape
        `(*kernel.batch_shape, num_outputs // 2, num_inputs)`.

    Raises:
        UnsupportedError: If `num_outputs` is odd or not strictly positive.
    """
    if num_outputs % 2:
        raise UnsupportedError(
            f"Expected an even number of output features, but received {num_outputs=}."
        )
    if num_outputs <= 0:
        # Without this guard, `num_outputs=0` only fails later with an opaque
        # ZeroDivisionError when the normalization constant is computed.
        raise UnsupportedError(
            "Expected a positive number of output features, "
            f"but received {num_outputs=}."
        )

    input_transform = InverseLengthscaleTransform(kernel)
    if kernel.active_dims is not None:
        # The kernel only sees a subset of the input columns, so the weights must
        # be sized for that subset and the inputs restricted accordingly.
        num_inputs = len(kernel.active_dims)
        input_transform = ChainedTransform(
            input_transform, FeatureSelector(indices=kernel.active_dims)
        )

    # One weight vector per sine/cosine pair; all batches are drawn in a single
    # call and then folded back into the kernel's batch shape.
    num_pairs = num_outputs // 2
    weight = weight_generator(
        Size([kernel.batch_shape.numel() * num_outputs // 2, num_inputs])
    ).reshape(*kernel.batch_shape, num_pairs, num_inputs)

    # Scale passed to the sine/cosine transform: sqrt(2 / num_outputs).
    output_transform = SineCosineTransform(
        torch.tensor((2 / num_outputs) ** 0.5, device=kernel.device, dtype=kernel.dtype)
    )
    return KernelFeatureMap(
        kernel=kernel,
        weight=weight,
        input_transform=input_transform,
        output_transform=output_transform,
    )
110+
111+
112+
@GenKernelFeatures.register(kernels.RBFKernel)
def _gen_kernel_features_rbf(
    kernel: kernels.RBFKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    r"""Generate Fourier features for an `RBFKernel`.

    Weight vectors are obtained from `draw_sobol_normal_samples`, placed on the
    device and dtype of `kernel.lengthscale`, and handed to
    `_gen_fourier_features`.

    Args:
        kernel: The RBF kernel to be represented.
        num_inputs: The number of input features.
        num_outputs: The number of Fourier features.

    Returns:
        The `KernelFeatureMap` built by `_gen_fourier_features`.
    """

    def _draw_normal_weights(shape: Size) -> Tensor:
        # The requested shape must unpack into exactly (rows, columns).
        try:
            num_rows, num_cols = shape
        except ValueError:
            raise UnsupportedError(
                f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
            )

        lengthscale = kernel.lengthscale
        return draw_sobol_normal_samples(
            n=num_rows,
            d=num_cols,
            device=lengthscale.device,
            dtype=lengthscale.dtype,
        )

    return _gen_fourier_features(
        kernel=kernel,
        weight_generator=_draw_normal_weights,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
    )
140+
141+
142+
@GenKernelFeatures.register(kernels.MaternKernel)
def _gen_kernel_features_matern(
    kernel: kernels.MaternKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    r"""Generate Fourier features for a `MaternKernel`.

    Weight vectors are Sobol-based normal draws, each row rescaled by the inverse
    square root of a `Gamma(nu, nu)` sample, where `nu = kernel.nu`. The resulting
    generator is handed to `_gen_fourier_features`.

    Args:
        kernel: The Matern kernel to be represented.
        num_inputs: The number of input features.
        num_outputs: The number of Fourier features.

    Returns:
        The `KernelFeatureMap` built by `_gen_fourier_features`.
    """

    def _draw_scaled_normal_weights(shape: Size) -> Tensor:
        # The requested shape must unpack into exactly (rows, columns).
        try:
            num_rows, num_cols = shape
        except ValueError:
            raise UnsupportedError(
                f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
            )

        lengthscale = kernel.lengthscale
        tensor_kwargs = {"device": lengthscale.device, "dtype": lengthscale.dtype}
        nu = torch.tensor(kernel.nu, **tensor_kwargs)

        # Draw the normals first and the Gamma scales second, so random numbers
        # are consumed in a fixed order.
        normals = draw_sobol_normal_samples(n=num_rows, d=num_cols, **tensor_kwargs)
        row_scales = Gamma(nu, nu).rsample((num_rows, 1)).rsqrt()
        return row_scales * normals

    return _gen_fourier_features(
        kernel=kernel,
        weight_generator=_draw_scaled_normal_weights,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
    )
169+
170+
171+
@GenKernelFeatures.register(kernels.ScaleKernel)
def _gen_kernel_features_scale(
    kernel: kernels.ScaleKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    r"""Generate features for a `ScaleKernel` by wrapping its base kernel's features.

    The feature map of `kernel.base_kernel` is generated first. Its transforms are
    then extended: a `FeatureSelector` is chained onto the input transform when the
    scale kernel has its own `active_dims`, and an `OutputscaleTransform` is chained
    with the existing output transform.

    Args:
        kernel: The scale kernel to be represented.
        num_inputs: The number of input features. Replaced by
            `len(kernel.active_dims)` for the base kernel when `kernel.active_dims`
            is not None.
        num_outputs: The number of kernel features.

    Returns:
        The base kernel's `KernelFeatureMap`, modified in place and returned.
    """
    base_kernel = kernel.base_kernel
    selected_dims = kernel.active_dims

    if selected_dims is None:
        base_num_inputs = num_inputs
    else:
        base_num_inputs = len(selected_dims)

    feature_map = gen_kernel_features(
        base_kernel,
        num_inputs=base_num_inputs,
        num_outputs=num_outputs,
    )

    # Only add a selector when the scale kernel's active dims are not the very same
    # object as the base kernel's (identity check, not equality).
    if selected_dims is not None:
        if selected_dims is not base_kernel.active_dims:
            feature_map.input_transform = ChainedTransform(
                feature_map.input_transform, FeatureSelector(indices=selected_dims)
            )

    outputscale_transform = OutputscaleTransform(kernel)
    feature_map.output_transform = ChainedTransform(
        outputscale_transform, feature_map.output_transform
    )
    return feature_map

0 commit comments

Comments
 (0)