
Commit d935d10

henrymoss authored and facebook-github-bot committed
Inducing Point Allocators for Sparse GPs (#240)
Summary:
X-link: facebookresearch/aepsych#240

## Motivation

As requested by eytan, I have had a go at implementing BO-specific inducing point allocation (IPA) strategies from my new paper (https://arxiv.org/abs/2301.10123). These allow you to build sparse Gaussian processes that are somewhat customized to the particular task at hand, be it BO, active learning, etc.

1. I have rewritten part of your approximate GP code to allow the specification of custom IPA strategies.
2. BoTorch's existing behaviour (using the greedy variance reduction of Burt et al. (2020)) is a special case of the new functionality and remains the default.
3. I have included one of the new IPAs from my paper to help demonstrate how custom IPAs can be defined. The one I chose is quite complicated (requiring access to the model from the previous BO step), so being able to implement it was a good test case to check that this code for custom IPA initialisation isn't stupid.
4. There ended up being a lot of IPA-related code, so I moved it all into a utility file, with its own testing file.

Pull Request resolved: #1652

Test Plan:
I have added loads of tests in test_inducing_point_allocators.py. I also added some extra tests higher up in the code in test_approximate_gp.py. There is perhaps a little bit of redundancy between the two sets of tests, but they are testing different interfaces.

## Related PRs

I will write a follow-up PR with a notebook demonstrating this new inducing point allocation functionality and showing how a user can define their own IPA.

Reviewed By: saitcakmak

Differential Revision: D43556981

Pulled By: esantorella

fbshipit-source-id: 1fd989c00e6dabf302beec5f4bde0279ccaf158b
1 parent 5df2fab commit d935d10
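As a quick illustration of the API described above, here is a minimal usage sketch (not part of the commit). It relies only on names that appear in the diff below — `SingleTaskVariationalGP`, `GreedyVarianceReduction`, the new `inducing_point_allocator` argument, and `learn_inducing_points` — and the training data is made up.

```python
import torch

from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.models.utils.inducing_point_allocators import GreedyVarianceReduction

# Illustrative data: 100 noisy observations of a 2-d function.
train_X = torch.rand(100, 2, dtype=torch.double)
train_Y = torch.sin(train_X.sum(dim=-1, keepdim=True))
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)

# Passing the allocator explicitly; omitting `inducing_point_allocator`
# falls back to the same GreedyVarianceReduction, so existing behaviour
# is unchanged.
model = SingleTaskVariationalGP(
    train_X=train_X,
    train_Y=train_Y,
    inducing_points=25,
    inducing_point_allocator=GreedyVarianceReduction(),
    learn_inducing_points=False,  # avoids the UserWarning added in this PR
)
```

Since the default allocation is unchanged, existing code that does not pass the new argument keeps its current behaviour.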

File tree

6 files changed: +750 -149 lines changed

botorch/acquisition/max_value_entropy_search.py

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@

 .. [Moss2021gibbon]
     Moss, H. B., et al.,
-    GIBBON: General-purpose Information-Based Bayesian OptimisatioN
-    arXiv:2102.03324, 2021
+    GIBBON: General-purpose Information-Based Bayesian OptimisatioN.
+    Journal of Machine Learning Research, 2021.

 .. [Takeno2020mfmves]
     S. Takeno, H. Fukuoka, Y. Tsukada, T. Koyama, M. Shiga, I. Takeuchi,

botorch/models/approximate_gp.py

Lines changed: 52 additions & 144 deletions
@@ -13,29 +13,36 @@
     Journal of Machine Learning Research, 2020,
     http://jmlr.org/papers/v21/19-1015.html.

-.. [chen2018dpp]
-    Laming Chen and Guoxin Zhang and Hanning Zhou, Fast greedy MAP inference
-    for determinantal point process to improve recommendation diversity,
-    Proceedings of the 32nd International Conference on Neural Information
-    Processing Systems, 2018, https://arxiv.org/abs/1709.05135.
-
 .. [hensman2013svgp]
     James Hensman and Nicolo Fusi and Neil D. Lawrence, Gaussian Processes
     for Big Data, Proceedings of the 29th Conference on Uncertainty in
     Artificial Intelligence, 2013, https://arxiv.org/abs/1309.6835.

+.. [moss2023ipa]
+    Henry B. Moss and Sebastian W. Ober and Victor Picheny,
+    Inducing Point Allocation for Sparse Gaussian Processes
+    in High-Throughput Bayesian Optimization, Proceedings of
+    the 25th International Conference on Artificial Intelligence
+    and Statistics, 2023, https://arxiv.org/pdf/2301.10123.pdf.
+
 """

 from __future__ import annotations

 import copy
+import warnings
+
 from typing import Optional, Type, Union

 import torch
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
+from botorch.models.utils.inducing_point_allocators import (
+    GreedyVarianceReduction,
+    InducingPointAllocator,
+)
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from gpytorch.constraints import GreaterThan
 from gpytorch.distributions import MultivariateNormal
@@ -47,7 +54,6 @@
 )
 from gpytorch.means import ConstantMean, Mean
 from gpytorch.models import ApproximateGP
-from gpytorch.module import Module
 from gpytorch.priors import GammaPrior
 from gpytorch.utils.memoize import clear_cache_hook
 from gpytorch.variational import (
@@ -57,12 +63,10 @@
     IndependentMultitaskVariationalStrategy,
     VariationalStrategy,
 )
-from linear_operator.operators import LinearOperator
 from torch import Tensor


 MIN_INFERRED_NOISE_LEVEL = 1e-4
-NEG_INF = -(torch.tensor(float("inf")))


 class ApproximateGPyTorchModel(GPyTorchModel):
@@ -148,7 +152,8 @@ class _SingleTaskVariationalGP(ApproximateGP):
     Base class wrapper for a stochastic variational Gaussian Process (SVGP)
     model [hensman2013svgp]_.

-    Uses pivoted Cholesky initialization for the inducing points.
+    Uses by default pivoted Cholesky initialization for allocating inducing points,
+    however, custom inducing point allocators can be provided.
     """

     def __init__(
@@ -162,6 +167,7 @@ def __init__(
         variational_distribution: Optional[_VariationalDistribution] = None,
         variational_strategy: Type[_VariationalStrategy] = VariationalStrategy,
         inducing_points: Optional[Union[Tensor, int]] = None,
+        inducing_point_allocator: Optional[InducingPointAllocator] = None,
     ) -> None:
         r"""
         Args:
@@ -179,6 +185,9 @@
                 VariationalStrategy). The default setting uses "whitening" of the
                 variational distribution to make training easier.
             inducing_points: The number or specific locations of the inducing points.
+            inducing_point_allocator: The `InducingPointAllocator` used to
+                initialize the inducing point locations. If omitted,
+                uses `GreedyVarianceReduction`.
         """
         # We use the model subclass wrapper to deal with input / outcome transforms.
         # The number of outputs will be correct here due to the check in
@@ -209,14 +218,17 @@
             "covar_module.base_kernel.raw_lengthscale": -3,
         }

-        # initialize inducing points with a pivoted cholesky init if they are not given
+        if inducing_point_allocator is None:
+            inducing_point_allocator = GreedyVarianceReduction()
+
+        # initialize inducing points if they are not given
         if not isinstance(inducing_points, Tensor):
             if inducing_points is None:
                 # number of inducing points is 25% the number of data points
                 # as a heuristic
                 inducing_points = int(0.25 * train_X.shape[-2])

-            inducing_points = _select_inducing_points(
+            inducing_points = inducing_point_allocator.allocate_inducing_points(
                 inputs=train_X,
                 covar_module=covar_module,
                 num_inducing=inducing_points,
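For orientation, a duck-typed allocator that satisfies just the `allocate_inducing_points` call shown in this hunk could look like the sketch below. This is illustrative only: the argument names mirror the call sites visible in this diff, and a real custom allocator should subclass `InducingPointAllocator` (imported above) and implement whatever abstract hooks that base class defines.

```python
import torch
from torch import Tensor


class RandomSubsetAllocator:
    """Toy strategy: pick a random subset of the training inputs as inducing
    points, ignoring the kernel. Mirrors only the interface used above."""

    def allocate_inducing_points(
        self,
        inputs: Tensor,
        covar_module,  # unused by this toy strategy
        num_inducing: int,
        input_batch_shape: torch.Size = torch.Size([]),
    ) -> Tensor:
        # Sample `num_inducing` rows (without replacement) along the -2 dim.
        perm = torch.randperm(inputs.shape[-2], device=inputs.device)
        return inputs[..., perm[:num_inducing], :].clone()
```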
@@ -255,8 +267,14 @@ def forward(self, X) -> MultivariateNormal:


 class SingleTaskVariationalGP(ApproximateGPyTorchModel):
-    r"""A single-task variational GP model following [hensman2013svgp]_ with pivoted
-    Cholesky initialization following [chen2018dpp]_ and [burt2020svgp]_.
+    r"""A single-task variational GP model following [hensman2013svgp]_.
+
+    By default, the inducing points are initialized though the
+    `GreedyVarianceReduction` of [burt2020svgp]_, which is known to be
+    effective for building globally accurate models. However, custom
+    inducing point allocators designed for specific down-stream tasks can also be
+    provided (see [moss2023ipa]_ for details), e.g. `GreedyImprovementReduction`
+    when the goal is to build a model suitable for standard BO.

     A single-task variational GP using relatively strong priors on the Kernel
     hyperparameters, which work best when covariates are normalized to the unit
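To connect the docstring above to the commit summary: `GreedyImprovementReduction` is the BO-specific allocator from [moss2023ipa]_ that needs the model from the previous BO step. The constructor arguments in this sketch (`model`, `maximize`) are assumptions based on that description, not lines shown in this diff; the data and sizes are made up.

```python
import torch

from botorch.models.approximate_gp import SingleTaskVariationalGP
from botorch.models.utils.inducing_point_allocators import (
    GreedyImprovementReduction,
    GreedyVarianceReduction,
)

train_X = torch.rand(60, 2, dtype=torch.double)
train_Y = -(train_X - 0.5).pow(2).sum(dim=-1, keepdim=True)

# First step: default, globally accurate allocation.
previous_model = SingleTaskVariationalGP(
    train_X=train_X,
    train_Y=train_Y,
    inducing_points=15,
    inducing_point_allocator=GreedyVarianceReduction(),
    learn_inducing_points=False,
)

# Later step: focus inducing points where improvement is likely, re-using the
# previous model (assumed constructor signature, see note above).
allocator = GreedyImprovementReduction(model=previous_model, maximize=True)
current_model = SingleTaskVariationalGP(
    train_X=train_X,
    train_Y=train_Y,
    inducing_points=15,
    inducing_point_allocator=allocator,
    learn_inducing_points=False,
)
```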
@@ -299,6 +317,7 @@ def __init__(
         inducing_points: Optional[Union[Tensor, int]] = None,
         outcome_transform: Optional[OutcomeTransform] = None,
         input_transform: Optional[InputTransform] = None,
+        inducing_point_allocator: Optional[InducingPointAllocator] = None,
     ) -> None:
         r"""
         Args:
@@ -319,6 +338,9 @@
                 VariationalStrategy). The default setting uses "whitening" of the
                 variational distribution to make training easier.
             inducing_points: The number or specific locations of the inducing points.
+            inducing_point_allocator: The `InducingPointAllocator` used to
+                initialize the inducing point locations. If omitted,
+                uses `GreedyVarianceReduction`.
         """
         with torch.no_grad():
             transformed_X = self.transform_inputs(
@@ -357,6 +379,19 @@ def __init__(
         else:
             self._is_custom_likelihood = True

+        if learn_inducing_points and (inducing_point_allocator is not None):
+            warnings.warn(
+                "After all the effort of specifying an inducing point allocator, "
+                "you probably want to stop the inducing point locations "
+                "being further optimized during the model fit. If so "
+                "then set `learn_inducing_points` to False.",
+                UserWarning,
+            )
+
+        if inducing_point_allocator is None:
+            inducing_point_allocator = GreedyVarianceReduction()
+        self._inducing_point_allocator = inducing_point_allocator
+
         model = _SingleTaskVariationalGP(
             train_X=transformed_X,
             train_Y=train_Y,
@@ -367,6 +402,7 @@
             variational_distribution=variational_distribution,
             variational_strategy=variational_strategy,
             inducing_points=inducing_points,
+            inducing_point_allocator=self._inducing_point_allocator,
         )

         super().__init__(model=model, likelihood=likelihood, num_outputs=num_outputs)
@@ -390,7 +426,7 @@ def init_inducing_points(
     ) -> Tensor:
         r"""
         Reinitialize the inducing point locations in-place with the current kernel
-        applied to `inputs`.
+        applied to `inputs` through the model's inducing point allocation strategy.
         The variational distribution and variational strategy caches are reset.

         Args:
@@ -407,7 +443,7 @@ def init_inducing_points(

         with torch.no_grad():
             num_inducing = var_strat.inducing_points.size(-2)
-            inducing_points = _select_inducing_points(
+            inducing_points = self._inducing_point_allocator.allocate_inducing_points(
                 inputs=inputs,
                 covar_module=self.model.covar_module,
                 num_inducing=num_inducing,
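A brief usage note for the reallocation path changed here (a sketch, with `model` and `all_X` as illustrative names, not from the diff): after new observations are appended to the training data, the inducing points can be re-selected with the allocator the model was constructed with.

```python
# Re-select inducing point locations from the accumulated inputs `all_X` using
# the model's stored allocator; per the docstring above, this also resets the
# variational distribution and variational strategy caches.
new_inducing_points = model.init_inducing_points(all_X)
```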
@@ -417,131 +453,3 @@
         var_strat.variational_params_initialized.fill_(0)

         return inducing_points
-
-
-def _select_inducing_points(
-    inputs: Tensor,
-    covar_module: Module,
-    num_inducing: int,
-    input_batch_shape: torch.Size,
-) -> Tensor:
-    r"""
-    Utility function that evaluates a kernel at given inputs and selects inducing point
-    locations based on the pivoted Cholesky heuristic.
-
-    Args:
-        inputs: A (*batch_shape, n, d)-dim input data tensor.
-        covar_module: GPyTorch Module returning a LinearOperator kernel matrix.
-        num_inducing: The maximun number (m) of inducing points (m <= n).
-        input_batch_shape: The non-task-related batch shape.
-
-    Returns:
-        A (*batch_shape, m, d)-dim tensor of inducing point locations.
-    """
-
-    train_train_kernel = covar_module(inputs).evaluate_kernel()
-
-    # base case
-    if train_train_kernel.ndimension() == 2:
-        inducing_points = _pivoted_cholesky_init(
-            train_inputs=inputs,
-            kernel_matrix=train_train_kernel,
-            max_length=num_inducing,
-        )
-    # multi-task case
-    elif train_train_kernel.ndimension() == 3 and len(input_batch_shape) == 0:
-        input_element = inputs[0] if inputs.ndimension() == 3 else inputs
-        kernel_element = train_train_kernel[0]
-        inducing_points = _pivoted_cholesky_init(
-            train_inputs=input_element,
-            kernel_matrix=kernel_element,
-            max_length=num_inducing,
-        )
-    # batched input cases
-    else:
-        batched_inputs = (
-            inputs.expand(*input_batch_shape, -1, -1)
-            if inputs.ndimension() == 2
-            else inputs
-        )
-        reshaped_inputs = batched_inputs.flatten(end_dim=-3)
-        inducing_points = []
-        for input_element in reshaped_inputs:
-            # the extra kernel evals are a little wasteful but make it
-            # easier to infer the task batch size
-            kernel_element = covar_module(input_element).evaluate_kernel()
-            # handle extra task batch dimension
-            kernel_element = (
-                kernel_element[0]
-                if kernel_element.ndimension() == 3
-                else kernel_element
-            )
-            inducing_points.append(
-                _pivoted_cholesky_init(
-                    train_inputs=input_element,
-                    kernel_matrix=kernel_element,
-                    max_length=num_inducing,
-                )
-            )
-        inducing_points = torch.stack(inducing_points).view(
-            *input_batch_shape, num_inducing, -1
-        )
-
-    return inducing_points
-
-
-def _pivoted_cholesky_init(
-    train_inputs: Tensor,
-    kernel_matrix: Union[Tensor, LinearOperator],
-    max_length: int,
-    epsilon: float = 1e-6,
-) -> Tensor:
-    r"""
-    A pivoted cholesky initialization method for the inducing points,
-    originally proposed in [burt2020svgp]_ with the algorithm itself coming from
-    [chen2018dpp]_. Code is a PyTorch version from [chen2018dpp]_, copied from
-    https://github.com/laming-chen/fast-map-dpp/blob/master/dpp.py.
-
-    Args:
-        train_inputs: training inputs (of shape n x d)
-        kernel_matrix: kernel matrix on the training
-            inputs
-        max_length: number of inducing points to initialize
-        epsilon: numerical jitter for stability.
-
-    Returns:
-        max_length x d tensor of the training inputs corresponding to the top
-        max_length pivots of the training kernel matrix
-    """
-
-    # this is numerically equivalent to iteratively performing a pivoted cholesky
-    # while storing the diagonal pivots at each iteration
-    # TODO: use gpytorch's pivoted cholesky instead once that gets an exposed list
-    # TODO: ensure this works in batch mode, which it does not currently.
-
-    item_size = kernel_matrix.shape[-2]
-    cis = torch.zeros(
-        (max_length, item_size), device=kernel_matrix.device, dtype=kernel_matrix.dtype
-    )
-    di2s = kernel_matrix.diag()
-    selected_items = []
-    selected_item = torch.argmax(di2s)
-    selected_items.append(selected_item)
-
-    while len(selected_items) < max_length:
-        k = len(selected_items) - 1
-        ci_optimal = cis[:k, selected_item]
-        di_optimal = torch.sqrt(di2s[selected_item])
-        elements = kernel_matrix[..., selected_item, :]
-        eis = (elements - torch.matmul(ci_optimal, cis[:k, :])) / di_optimal
-        cis[k, :] = eis
-        di2s = di2s - eis.pow(2.0)
-        di2s[selected_item] = NEG_INF
-        selected_item = torch.argmax(di2s)
-        if di2s[selected_item] < epsilon:
-            break
-        selected_items.append(selected_item)
-
-    ind_points = train_inputs[torch.stack(selected_items)]
-
-    return ind_points
