
Commit 14587f6

saitcakmak authored and facebook-github-bot committed
Extend Posterior API to support torch distributions & overhaul MCSampler API (#1254)
Summary:
X-link: facebook/Ax#1254
X-link: facebookresearch/aepsych#193
Pull Request resolved: #1486

The main goal here is to broadly support non-Gaussian posteriors.

- Adds a generic `TorchPosterior` which wraps a torch `Distribution`. This defines a few properties that we commonly expect, and calls the `distribution` for the rest.
- For a unified plotting API, this shifts away from mean & variance to a quantile function. Most torch distributions implement the inverse CDF, which is used as the quantile function. For others, the user should implement it at either the distribution or the posterior level.
- Hands off the burden of base sample handling from the posterior to the samplers. Using a dispatcher-based `get_sampler` method, we can support SAA with mixed posteriors without having to shuffle base samples in a `PosteriorList`, as long as all base distributions have a corresponding sampler and support base samples.
- Adds `ListSampler` for sampling from `PosteriorList`.
- Adds `ForkedRNGSampler` and `StochasticSampler` for sampling from posteriors without base samples.
- Adds `rsample_from_base_samples` for sampling with `base_samples` / with a `sampler`.
- Absorbs `FullyBayesianPosteriorList` into `PosteriorList`.
- For MC acqfs, introduces a `get_posterior_samples` method for sampling from the posterior with base samples / a sampler. If a sampler was not specified, this constructs the appropriate sampler for the posterior using `get_sampler`, eliminating the need to construct a sampler in `__init__`, which we used to do under the assumption of Gaussian posteriors.

TODOs:
- Relax the Gaussian assumption in acquisition functions & utilities. Some of this might be addressed in a follow-up diff.
- Update the website / docs & tutorials to clear up some of the Gaussian assumptions and introduce the new, relaxed API. Likely a follow-up diff.

Other notables:
- See D39760855 for usage of TorchDistribution in SkewGP.
- TransformedPosterior could serve as the fallback option for derived posteriors.
- MC samplers no longer support resample or collapse_batch_dims(=False). These can be handled by i) not using base samples, or ii) using torch.fork_rng and sampling without base samples from that. Samplers are only meant to support SAA. Introduces `ForkedRNGSampler` and `StochasticSampler` as convenience samplers for these use cases.
- Introduces `batch_range_override` for the sampler to support edge cases where we may want to override `posterior.batch_range` (needed in `qMultiStepLookahead`).
- Removes unused sampling utilities `construct_base_samples(_from_posterior)`, which assume a Gaussian posterior.
- Moves the main logic of the `_set_sampler` method of CachedCholesky subclasses to an `_update_base_samples` method on the samplers, and simplifies these classes a bit more.

Reviewed By: Balandat

Differential Revision: D39759489

fbshipit-source-id: f4db866320bab9a5455dfc0c2f7fe2cc15385453
1 parent 24ee799 commit 14587f6
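To make the API change above concrete, here is a minimal sketch of the new sampler construction and the `get_sampler` dispatcher. The `SingleTaskGP` toy setup and the tensor shapes are illustrative assumptions, not part of this commit:

import torch
from botorch.models import SingleTaskGP
from botorch.sampling.get_sampler import get_sampler
from botorch.sampling.normal import SobolQMCNormalSampler

model = SingleTaskGP(torch.rand(8, 2), torch.rand(8, 1))  # illustrative toy model
posterior = model.posterior(torch.rand(4, 2))

# Old API (removed): SobolQMCNormalSampler(num_samples=64, resample=False,
# collapse_batch_dims=True). The new API takes a sample_shape instead:
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([64]))
samples = sampler(posterior)  # 64 x 4 x 1

# Or let the dispatcher pick the appropriate sampler for the posterior type:
auto_sampler = get_sampler(posterior=posterior, sample_shape=torch.Size([64]))
samples = auto_sampler(posterior)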

File tree

101 files changed: +2811 −2889 lines changed


botorch/acquisition/acquisition.py
Lines changed: 36 additions & 0 deletions

@@ -12,9 +12,12 @@
 from abc import ABC, abstractmethod
 from typing import Callable, Optional
 
+import torch
 from botorch.exceptions import BotorchWarning, UnsupportedError
 from botorch.models.model import Model
 from botorch.posteriors.posterior import Posterior
+from botorch.sampling.base import MCSampler
+from botorch.sampling.get_sampler import get_sampler
 from torch import Tensor
 from torch.nn import Module
 
@@ -132,3 +135,36 @@ def extract_candidates(self, X_full: Tensor) -> Tensor:
             A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
         """
         pass  # pragma: no cover
+
+
+class MCSamplerMixin(ABC):
+    r"""A mix-in for adding sampler functionality into an acquisition function class.
+
+    Attributes:
+        _default_sample_shape: The `sample_shape` for the default sampler.
+
+    :meta private:
+    """
+
+    _default_sample_shape = torch.Size([512])
+
+    def __init__(self, sampler: Optional[MCSampler] = None) -> None:
+        r"""Register the sampler on the acquisition function.
+
+        Args:
+            sampler: The sampler used to draw base samples for MC-based acquisition
+                functions. If `None`, a sampler is generated using `get_sampler`.
+        """
+        self.sampler = sampler
+
+    def get_posterior_samples(self, posterior: Posterior) -> Tensor:
+        r"""Sample from the posterior using the sampler.
+
+        Args:
+            posterior: The posterior to sample from.
+        """
+        if self.sampler is None:
+            self.sampler = get_sampler(
+                posterior=posterior, sample_shape=self._default_sample_shape
+            )
+        return self.sampler(posterior=posterior)
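The new `MCSamplerMixin` above is what lets MC acquisition functions defer sampler construction until the first posterior is seen. A minimal sketch of the intended usage pattern, mirroring how `qKnowledgeGradient` adopts the mixin later in this commit (the `qSampleMean` class is hypothetical, not part of this diff):

from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from torch import Tensor


class qSampleMean(AcquisitionFunction, MCSamplerMixin):
    """Hypothetical MC acqf: scores X by an MC estimate of the posterior mean."""

    def __init__(self, model: Model, sampler: Optional[MCSampler] = None) -> None:
        AcquisitionFunction.__init__(self, model=model)
        # sampler may be None here; get_posterior_samples lazily builds one
        # via get_sampler using _default_sample_shape (torch.Size([512])).
        MCSamplerMixin.__init__(self, sampler=sampler)

    def forward(self, X: Tensor) -> Tensor:
        posterior = self.model.posterior(X)
        samples = self.get_posterior_samples(posterior)  # sample_shape x b x q x m
        return samples.mean(dim=0).sum(dim=(-2, -1))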

botorch/acquisition/active_learning.py
Lines changed: 5 additions & 7 deletions

@@ -25,12 +25,14 @@
 
 from typing import Optional
 
+import torch
 from botorch import settings
 from botorch.acquisition.analytic import AnalyticAcquisitionFunction
 from botorch.acquisition.monte_carlo import MCAcquisitionFunction
 from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
 from botorch.models.model import Model
-from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
+from botorch.sampling.base import MCSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
 from torch import Tensor
 
@@ -81,9 +83,7 @@ def __init__(
             # variance does not depend on the samples y (only on x), which is true for
             # standard GP models, but not in general (e.g. for other likelihoods or
             # heteroskedastic GPs using a separate noise model fit on data).
-            sampler = SobolQMCNormalSampler(
-                num_samples=1, resample=False, collapse_batch_dims=True
-            )
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
         self.sampler = sampler
         self.X_pending = X_pending
         self.register_buffer("mc_points", mc_points)
@@ -150,8 +150,6 @@ def __init__(
                 two samples. Can be implemented via GenericMCObjective.
             sampler: The sampler used for drawing MC samples.
         """
-        if sampler is None:
-            sampler = SobolQMCNormalSampler(num_samples=512, collapse_batch_dims=True)
         super().__init__(
             model=model, sampler=sampler, objective=objective, X_pending=None
         )
@@ -175,7 +173,7 @@ def forward(self, X: Tensor) -> Tensor:
         # The output is of shape batch_shape x 2 x d
         # For PairwiseGP, d = 1
         post = self.model.posterior(X)
-        samples = self.sampler(post)  # num_samples x batch_shape x 2 x d
+        samples = self.get_posterior_samples(post)  # num_samples x batch_shape x 2 x d
 
         # The output is of shape num_samples x batch_shape x q/2 x d
         # assuming the comparison is made between the 2 * i and 2 * i + 1 elements

botorch/acquisition/analytic.py
Lines changed: 3 additions & 2 deletions

@@ -22,7 +22,6 @@
 from botorch.models.gp_regression import FixedNoiseGP
 from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import Model
-from botorch.sampling.samplers import SobolQMCNormalSampler
 from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
 from torch import Tensor
 from torch.distributions import Normal
@@ -561,9 +560,11 @@ def __init__(
                 "Only FixedNoiseGPs are currently supported for fantasy NEI"
             )
         # sample fantasies
+        from botorch.sampling.normal import SobolQMCNormalSampler
+
         with torch.no_grad():
             posterior = model.posterior(X=X_observed)
-            sampler = SobolQMCNormalSampler(num_fantasies)
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
             Y_fantasized = sampler(posterior).squeeze(-1)
         batch_X_observed = X_observed.expand(num_fantasies, *X_observed.shape)
         # The fantasy model will operate in batch mode

botorch/acquisition/cached_cholesky.py
Lines changed: 39 additions & 47 deletions

@@ -12,18 +12,15 @@
 
 import warnings
 from abc import ABC
-from typing import Optional
 
 import torch
-from botorch.exceptions.errors import UnsupportedError
 from botorch.exceptions.warnings import BotorchWarning
-from botorch.models import HigherOrderGP
-from botorch.models.deterministic import DeterministicModel
+from botorch.models.gpytorch import GPyTorchModel
+from botorch.models.higher_order_gp import HigherOrderGP
 from botorch.models.model import Model, ModelList
 from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.posteriors.posterior import Posterior
-from botorch.sampling.samplers import MCSampler
 from botorch.utils.low_rank import extract_batch_covar, sample_cached_cholesky
 from gpytorch import settings as gpt_settings
 from gpytorch.distributions.multitask_multivariate_normal import (
@@ -43,58 +40,32 @@ class CachedCholeskyMCAcquisitionFunction(ABC):
     :meta private:
     """
 
-    def _check_sampler(self) -> None:
-        r"""Check compatibility of sampler and model with a cached Cholesky."""
-        if not self.sampler.collapse_batch_dims:
-            raise UnsupportedError(
-                "Expected sampler to use `collapse_batch_dims=True`."
-            )
-        elif self.sampler.base_samples is not None:
-            warnings.warn(
-                message=(
-                    "sampler.base_samples is not None. The base_samples must be "
-                    "initialized to None. Resetting sampler.base_samples to None."
-                ),
-                category=BotorchWarning,
-            )
-            self.sampler.base_samples = None
-        elif self._uses_matheron and self.sampler.batch_range != (0, -1):
-            raise RuntimeError(
-                "sampler.batch_range is not (0, -1). This check requires that the "
-                "sampler.batch_range is (0, -1) with GPs that use Matheron's rule "
-                "for sampling, in order to properly collapse batch dimensions. "
-            )
-
     def _setup(
         self,
         model: Model,
-        sampler: Optional[MCSampler] = None,
         cache_root: bool = False,
-        check_sampler: bool = False,
     ) -> None:
         r"""Set class attributes and perform compatibility checks.
 
         Args:
             model: A model.
-            sampler: A sampler.
             cache_root: A boolean indicating whether to cache the Cholesky.
                 This might be overridden if the model is not compatible.
-            check_sampler: A boolean indicating whether to check the sampler.
-                The sampler is always checked if cache_root is True.
         """
         models = model.models if isinstance(model, ModelList) else [model]
-        self._is_mt = any(
+        if any(
             isinstance(m, (MultiTaskGP, KroneckerMultiTaskGP, HigherOrderGP))
+            or not isinstance(m, GPyTorchModel)
             for m in models
-        )
-        self._is_deterministic = any(isinstance(m, DeterministicModel) for m in models)
-        self._uses_matheron = any(
-            isinstance(m, (KroneckerMultiTaskGP, HigherOrderGP)) for m in models
-        )
-        if check_sampler or cache_root:
-            self._check_sampler()
-        if self._is_deterministic or self._is_mt:
-            cache_root = False
+        ):
+            if cache_root:
+                warnings.warn(
+                    "`cache_root` is only supported for GPyTorchModels (with the "
+                    f"exception of MultiTask models). Got model={model}. Setting "
+                    "`cache_root = False`.",
+                    RuntimeWarning,
+                )
+            cache_root = False
         self._cache_root = cache_root
@@ -118,10 +89,10 @@ def _compute_root_decomposition(
         Args:
             posterior: The posterior over f(X_baseline).
         """
-        if isinstance(posterior.mvn, MultitaskMultivariateNormal):
-            lazy_covar = extract_batch_covar(posterior.mvn)
+        if isinstance(posterior.distribution, MultitaskMultivariateNormal):
+            lazy_covar = extract_batch_covar(posterior.distribution)
         else:
-            lazy_covar = posterior.mvn.lazy_covariance_matrix
+            lazy_covar = posterior.distribution.lazy_covariance_matrix
         with gpt_settings.fast_computations.covar_root_decomposition(False):
             lazy_covar_root = lazy_covar.root_decomposition()
         return lazy_covar_root.root.to_dense()
@@ -142,7 +113,7 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
         # cached covariance (and box decompositions) and the new block.
         # But recomputing box decompositions every time the jitter changes would
         # be quite slow.
-        if not self._is_mt and self._cache_root and hasattr(self, "_baseline_L"):
+        if self._cache_root and hasattr(self, "_baseline_L"):
             try:
                 return sample_cached_cholesky(
                     posterior=posterior,
@@ -160,7 +131,7 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
             )
 
         # TODO: improve efficiency for multi-task models
-        samples = self.sampler(posterior)
+        samples = self.get_posterior_samples(posterior)
         if isinstance(self.model, HigherOrderGP):
             # Select the correct q-batch dimension for HOGP.
             q_dim = -self.model._num_dimensions
@@ -170,3 +141,24 @@ def _get_f_X_samples(self, posterior: GPyTorchPosterior, q_in: int) -> Tensor:
             return samples.index_select(q_dim, q_idcs)
         else:
             return samples[..., -q_in:, :]
+
+    def _set_sampler(
+        self,
+        q_in: int,
+        posterior: Posterior,
+    ) -> None:
+        r"""Update the sampler to use the original base samples for X_baseline.
+
+        Args:
+            q_in: The effective input batch size. This is typically equal to the
+                q-batch size of `X`. However, if using a one-to-many input transform,
+                e.g., `InputPerturbation` with `n_w` perturbations, the posterior will
+                have `n_w` points on the q-batch for each point on the q-batch of `X`.
+                In which case, `q_in = q * n_w` is used.
+            posterior: The posterior.
+        """
+        if self.q_in != q_in and self.base_sampler is not None:
+            self.sampler._update_base_samples(
+                posterior=posterior, base_sampler=self.base_sampler
+            )
+            self.q_in = q_in
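To situate `_set_sampler` above: when the q-batch size changes between evaluations, the sampler's base samples must be re-aligned so that the cached draws over `X_baseline` are reused and only the new candidate points receive fresh base samples. A rough, hypothetical `forward` fragment under those assumptions (the real q(N)EI classes initialize `q_in` and `base_sampler` in their constructors and additionally account for input transforms):

def forward(self, X: Tensor) -> Tensor:
    # Hypothetical fragment, not part of this diff.
    q = X.shape[-2]
    X_full = torch.cat([match_batch_shape(self.X_baseline, X), X], dim=-2)
    posterior = self.model.posterior(X_full)
    # Keep the base samples over X_baseline fixed; only the q candidate
    # points get new base samples if the q-batch size changed.
    self._set_sampler(q_in=q, posterior=posterior)
    samples = self._get_f_X_samples(posterior=posterior, q_in=q)
    ...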

botorch/acquisition/cost_aware.py
Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
 from botorch.acquisition.objective import IdentityMCObjective, MCAcquisitionObjective
 from botorch.exceptions.warnings import CostAwareWarning
 from botorch.models.model import Model
-from botorch.sampling.samplers import MCSampler
+from botorch.sampling.base import MCSampler
 from torch import Tensor
 from torch.nn import Module
botorch/acquisition/input_constructors.py
Lines changed: 9 additions & 7 deletions

@@ -83,9 +83,11 @@
 from botorch.exceptions.errors import UnsupportedError
 from botorch.models.cost import AffineFidelityCostModel
 from botorch.models.deterministic import FixedSingleSampleModel
+from botorch.models.gpytorch import GPyTorchModel
 from botorch.models.model import Model
 from botorch.optim.optimize import optimize_acqf
-from botorch.sampling.samplers import IIDNormalSampler, MCSampler, SobolQMCNormalSampler
+from botorch.sampling.base import MCSampler
+from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.containers import BotorchContainer
 from botorch.utils.datasets import BotorchDataset, SupervisedDataset
@@ -643,10 +645,10 @@ def construct_inputs_qUCB(
 def _get_sampler(mc_samples: int, qmc: bool) -> MCSampler:
     """Set up MC sampler for q(N)EHVI."""
     # initialize the sampler
-    seed = int(torch.randint(1, 10000, (1,)).item())
+    shape = torch.Size([mc_samples])
     if qmc:
-        return SobolQMCNormalSampler(num_samples=mc_samples, seed=seed)
-    return IIDNormalSampler(num_samples=mc_samples, seed=seed)
+        return SobolQMCNormalSampler(sample_shape=shape)
+    return IIDNormalSampler(sample_shape=shape)
 
 
 @acqf_input_constructor(ExpectedHypervolumeImprovement)
@@ -756,7 +758,7 @@ def construct_inputs_qEHVI(
     )
 
     sampler = kwargs.get("sampler")
-    if sampler is None:
+    if sampler is None and isinstance(model, GPyTorchModel):
         sampler = _get_sampler(
             mc_samples=kwargs.get("mc_samples", 128), qmc=kwargs.get("qmc", True)
         )
@@ -806,7 +808,7 @@ def construct_inputs_qNEHVI(
     cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
 
     sampler = kwargs.get("sampler")
-    if sampler is None:
+    if sampler is None and isinstance(model, GPyTorchModel):
         sampler = _get_sampler(
             mc_samples=kwargs.get("mc_samples", 128), qmc=kwargs.get("qmc", True)
         )
@@ -1175,7 +1177,7 @@ def optimize_objective(
             model=model,
             objective=objective,
             posterior_transform=posterior_transform,
-            sampler=sampler_cls(num_samples=mc_samples, seed=seed_inner),
+            sampler=sampler_cls(sample_shape=torch.Size([mc_samples]), seed=seed_inner),
         )
     else:
         acq_function = PosteriorMean(
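For illustration, the updated helper now fixes only the sample shape and leaves seeding to the sampler defaults. Note that `_get_sampler` is a private helper of this module, so the direct call below is purely a sketch of its behavior:

import torch

from botorch.acquisition.input_constructors import _get_sampler

qmc_sampler = _get_sampler(mc_samples=128, qmc=True)   # SobolQMCNormalSampler
iid_sampler = _get_sampler(mc_samples=128, qmc=False)  # IIDNormalSampler
assert qmc_sampler.sample_shape == torch.Size([128])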

botorch/acquisition/knowledge_gradient.py
Lines changed: 6 additions & 8 deletions

@@ -33,6 +33,7 @@
 from botorch import settings
 from botorch.acquisition.acquisition import (
     AcquisitionFunction,
+    MCSamplerMixin,
     OneShotAcquisitionFunction,
 )
 from botorch.acquisition.analytic import PosteriorMean
@@ -41,7 +42,8 @@
 from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
 from botorch.exceptions.errors import UnsupportedError
 from botorch.models.model import Model
-from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
+from botorch.sampling.base import MCSampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.transforms import (
     concatenate_pending_points,
     match_batch_shape,
@@ -108,9 +110,7 @@ def __init__(
                     "Must specify `num_fantasies` if no `sampler` is provided."
                 )
             # base samples should be fixed for joint optimization over X, X_fantasies
-            sampler = SobolQMCNormalSampler(
-                num_samples=num_fantasies, resample=False, collapse_batch_dims=True
-            )
+            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
         elif num_fantasies is not None:
             if sampler.sample_shape != torch.Size([num_fantasies]):
                 raise ValueError(
@@ -119,11 +119,10 @@ def __init__(
         else:
             num_fantasies = sampler.sample_shape[0]
         super(MCAcquisitionFunction, self).__init__(model=model)
+        MCSamplerMixin.__init__(self, sampler=sampler)
         # if not explicitly specified, we use the posterior mean for linear objs
         if isinstance(objective, MCAcquisitionObjective) and inner_sampler is None:
-            inner_sampler = SobolQMCNormalSampler(
-                num_samples=128, resample=False, collapse_batch_dims=True
-            )
+            inner_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([128]))
         elif objective is not None and not isinstance(
             objective, MCAcquisitionObjective
         ):
@@ -150,7 +149,6 @@ def __init__(
                 "If using a multi-output model without an objective, "
                 "posterior_transform must scalarize the output."
             )
-        self.sampler: MCSampler = sampler
         self.objective = objective
         self.posterior_transform = posterior_transform
         self.set_X_pending(X_pending)
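Under the new API, the fantasy sampler is specified by `sample_shape`; when only a sampler is given, `num_fantasies` is inferred from it, and passing both with mismatched values raises a `ValueError`. A brief sketch (the `SingleTaskGP` toy model is an illustrative assumption):

import torch
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.models import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler

model = SingleTaskGP(torch.rand(8, 2), torch.rand(8, 1))  # illustrative toy model

# Default: a Sobol fantasy sampler with sample_shape=torch.Size([num_fantasies]).
qkg = qKnowledgeGradient(model=model, num_fantasies=64)

# Equivalent explicit form; num_fantasies is read off sampler.sample_shape.
qkg = qKnowledgeGradient(
    model=model,
    sampler=SobolQMCNormalSampler(sample_shape=torch.Size([64])),
)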
