Skip to content

Commit 48e4ab5

Browse files
bletham authored and meta-codesync[bot] committed
Add MCAcquisition support to PFN (#3031)
Summary: Pull Request resolved: #3031

Makes all of the necessary changes to use MCAcquisitions, particularly qLogNoisyExpectedImprovement, with PFNModel and its BoundedRiemannPosterior, and to use PFN with multiple metrics in MBM. Including:
* Redo batch handling in PFNModel.posterior to do the right thing with the batch shapes sent in by BoTorch MCAcquisition.
* Change BoundedRiemannPosterior.rsample_from_base_samples to use N(0, I) base samples, and thus work with BoTorch samplers, and register the appropriate sampler.
* Add multi-batch handling to BoundedRiemannPosterior.icdf.
* MBM uses a ModelList if models are not GPs.

The PFNModel does not correctly handle q-batches yet. When MCAcquisition asks for a q-batch, it treats it the same as a t-batch, and in particular, the posteriors for each point in the q-batch are independent. This means that:
* Batch optimization does not work correctly. It will run and generate multiple points, but those points won't actually be properly conditioned on the earlier points in the batch.
* The incumbent best for EI in the noisy case is not handled correctly either, in that we are using an independent estimate for f_best rather than one jointly sampled. This is equivalent to using a plug-in estimate for f_best, which is the current behavior of DiscretizedExpectedImprovement.

Both of these will need to be addressed in the future, but I think this diff where things at least run end-to-end will be an easier starting point for that work.

Reviewed By: SamuelGabriel
Differential Revision: D79667144
fbshipit-source-id: 3eaad4f40e632e2f2c2848ff88429bfc83eb41c3
1 parent 2591f32 commit 48e4ab5

File tree

4 files changed

+198
-149
lines changed

4 files changed

+198
-149
lines changed

botorch_community/models/prior_fitted_network.py

Lines changed: 72 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import torch
1919
from botorch.acquisition.objective import PosteriorTransform
2020
from botorch.exceptions.errors import UnsupportedError
21-
2221
from botorch.logging import logger
2322
from botorch.models.model import Model
2423
from botorch.models.transforms.input import InputTransform
24+
from botorch.utils.transforms import match_batch_shape
2525
from botorch_community.models.utils.prior_fitted_network import (
2626
download_model,
2727
ModelPaths,
2828
)
2929
from botorch_community.posteriors.riemann import BoundedRiemannPosterior
30+
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
3031
from pfns.train import MainConfig # @manual=//pytorch/PFNs:PFNs
3132
from torch import Tensor
3233
from torch.nn import Module
@@ -58,7 +59,7 @@ def __init__(
5859
5960
Args:
6061
train_X: A `n x d` tensor of training features.
61-
train_Y: A `n x m` tensor of training observations.
62+
train_Y: A `n x 1` tensor of training observations.
6263
model: A pre-trained PFN model with the following
6364
forward(train_X, train_Y, X) -> logit predictions of shape
6465
`n x b x c` where c is the number of discrete buckets
@@ -95,40 +96,35 @@ def __init__(
9596
if train_Yvar is not None:
9697
logger.debug("train_Yvar provided but ignored for PFNModel.")
9798

98-
if not (1 <= train_Y.dim() <= 3):
99-
raise UnsupportedError("train_Y must be 1- to 3-dimensional.")
99+
if train_Y.dim() != 2:
100+
raise UnsupportedError("train_Y must be 2-dimensional.")
100101

101-
if not (2 <= train_X.dim() <= 3):
102-
raise UnsupportedError("train_X must be 2- to 3-dimensional.")
102+
if train_X.dim() != 2:
103+
raise UnsupportedError("train_X must be 2-dimensional.")
103104

104-
if train_Y.dim() == train_X.dim():
105-
if train_Y.shape[-1] > 1:
106-
raise UnsupportedError("Only 1 target allowed for PFNModel.")
107-
train_Y = train_Y.squeeze(-1)
105+
if train_Y.shape[-1] > 1:
106+
raise UnsupportedError("Only 1 target allowed for PFNModel.")
108107

109-
if (len(train_X.shape) != len(train_Y.shape) + 1) or (
110-
train_Y.shape != train_X.shape[:-1]
111-
):
108+
if train_X.shape[0] != train_Y.shape[0]:
112109
raise UnsupportedError(
113-
"train_X and train_Y must have the same shape except "
114-
"for the last dimension."
110+
"train_X and train_Y must have the same number of rows."
115111
)
116112

117-
if len(train_X.shape) == 2:
118-
# adding batch dimension
119-
train_X = train_X.unsqueeze(0)
120-
train_Y = train_Y.unsqueeze(0)
121-
122113
with torch.no_grad():
123114
self.transformed_X = self.transform_inputs(
124115
X=train_X, input_transform=input_transform
125116
)
126117

127-
self.train_X = train_X # shape: `b x n x d`
128-
self.train_Y = train_Y # shape: `b x n`
129-
self.pfn = model.to(train_X.device)
118+
self.train_X = train_X # shape: (n, d)
119+
self.train_Y = train_Y # shape: (n, 1)
120+
# Downstream botorch tooling expects a likelihood to be specified,
121+
# so here we use a FixedNoiseGaussianLikelihood that is unused.
122+
if train_Yvar is None:
123+
train_Yvar = torch.zeros_like(train_Y)
124+
self.likelihood = FixedNoiseGaussianLikelihood(noise=train_Yvar)
125+
self.pfn = model.to(device=train_X.device)
130126
self.batch_first = batch_first
131-
self.constant_model_kwargs = constant_model_kwargs
127+
self.constant_model_kwargs = constant_model_kwargs or {}
132128
if input_transform is not None:
133129
self.input_transform = input_transform
134130

@@ -146,23 +142,19 @@ def posterior(
146142
any `model.forward` or `model.likelihood` calls.
147143
148144
Args:
149-
X: A `b'? x b? x q x d`-dim Tensor, where `d` is the dimension of the
150-
feature space, `q` is the number of points considered jointly,
151-
and `b` is the batch dimension.
152-
We only allow `q=1` for PFNModel, so q can also be omitted, i.e.
153-
`b x d`-dim Tensor.
154-
**Currently not supported for PFNModel**.
145+
X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
146+
feature space.
155147
output_indices: **Currently not supported for PFNModel.**
156148
observation_noise: **Currently not supported for PFNModel**.
157149
posterior_transform: **Currently not supported for PFNModel**.
158150
159151
Returns:
160-
A `BoundedRiemannPosterior` object, representing a batch of `b` joint
161-
distributions over `q` points and `m` outputs each.
152+
A `BoundedRiemannPosterior`, representing a batch of `b? x q?`
153+
distributions.
162154
"""
163155
self.pfn.eval()
164156
if output_indices is not None:
165-
raise RuntimeError(
157+
raise UnsupportedError(
166158
"output_indices is not None. PFNModel should not "
167159
"be a multi-output model."
168160
)
@@ -173,60 +165,54 @@ def posterior(
173165
if posterior_transform is not None:
174166
raise UnsupportedError("posterior_transform is not supported for PFNModel.")
175167

176-
if not (1 <= len(X.shape) <= 4):
177-
raise UnsupportedError("X must be 1- to 4-dimensional.")
178-
179-
# X has shape b'? x b? x q? x d
180-
181-
orig_X_shape = X.shape
182-
q_in_orig_X_shape = len(X.shape) > 2
183-
184-
if len(X.shape) == 1:
185-
X = X.unsqueeze(0).unsqueeze(0).unsqueeze(0) # shape `b'=1 x b=1 x q=1 x d`
186-
elif len(X.shape) == 2:
187-
X = X.unsqueeze(1).unsqueeze(1) # shape `b' x b=1 x q=1 x d`
188-
elif len(X.shape) == 3:
189-
if self.train_X.shape[0] == 1:
190-
X = X.unsqueeze(1) # shape `b' x b=1 x q x d`
191-
else:
192-
X = X.unsqueeze(0) # shape `b'=1 x b x q x d`
193-
194-
# X has shape `b' x b x q x d`
195-
196-
if X.shape[2] != 1:
197-
raise UnsupportedError("Only q=1 is supported for PFNModel.")
198-
199-
# X has shape `b' x b x q=1 x d`
200-
X = self.transform_inputs(X)
201-
train_X = self.transformed_X # shape `b x n x d`
202-
train_Y = self.train_Y # shape `b x n`
203-
folded_X = X.transpose(0, 2).squeeze(0) # shape `b x b' x d
204-
205-
constant_model_kwargs = self.constant_model_kwargs or {}
206-
207-
if self.batch_first:
208-
logits = self.pfn(
209-
train_X.float(),
210-
train_X.float(),
211-
folded_X.float(),
212-
**constant_model_kwargs,
213-
).transpose(0, 1)
214-
else:
215-
logits = self.pfn(
216-
train_X.float().transpose(0, 1),
217-
train_Y.float().transpose(0, 1),
218-
folded_X.float().transpose(0, 1),
219-
**constant_model_kwargs,
220-
)
221-
222-
# logits shape `b' x b x logits_dim`
168+
orig_X_shape = X.shape # X has shape b? x q? x d
169+
X = self.prepare_X(X) # shape (b, q, d)
170+
train_X = match_batch_shape(self.transformed_X, X) # shape (b, n, d)
171+
train_Y = match_batch_shape(self.train_Y, X) # shape (b, n, 1)
223172

224-
logits = logits.view(
173+
probabilities = self.pfn_predict(
174+
X=X, train_X=train_X, train_Y=train_Y
175+
) # (b, q, num_buckets)
176+
probabilities = probabilities.view(
225177
*orig_X_shape[:-1], -1
226-
) # orig shape w/o q but logits_dim at end: `b'? x b? x q? x logits_dim`
227-
if q_in_orig_X_shape:
228-
logits = logits.squeeze(-2) # shape `b'? x b? x logits_dim`
178+
) # (b?, q?, num_buckets)
229179

230-
probabilities = logits.softmax(dim=-1)
180+
# Get posterior with the right dtype
181+
borders = self.pfn.criterion.borders.to(X.dtype)
182+
return BoundedRiemannPosterior(
183+
borders=borders,
184+
probabilities=probabilities,
185+
)
231186

232-
return BoundedRiemannPosterior(self.pfn.criterion.borders, probabilities)
187+
def prepare_X(self, X: Tensor) -> Tensor:
188+
if len(X.shape) > 3:
189+
raise UnsupportedError(f"X must be at most 3-d, got {X.shape}.")
190+
while len(X.shape) < 3:
191+
X = X.unsqueeze(0)
192+
193+
X = self.transform_inputs(X) # shape (b , q, d)
194+
return X
195+
196+
def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:
197+
"""
198+
X has shape (b, q, d)
199+
train_X has shape (b, n, d)
200+
train_Y has shape (b, n, 1)
201+
"""
202+
if not self.batch_first:
203+
X = X.transpose(0, 1) # shape (q, b, d)
204+
train_X = train_X.transpose(0, 1) # shape (n, b, d)
205+
train_Y = train_Y.transpose(0, 1) # shape (n, b, 1)
206+
207+
logits = self.pfn(
208+
train_X.float(),
209+
train_Y.float(),
210+
X.float(),
211+
**self.constant_model_kwargs,
212+
)
213+
if not self.batch_first:
214+
logits = logits.transpose(0, 1) # shape (b, q, num_buckets)
215+
logits = logits.to(X.dtype)
216+
217+
probabilities = logits.softmax(dim=-1) # shape (b, q, num_buckets)
218+
return probabilities

botorch_community/posteriors/riemann.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414

1515
import torch
1616
from botorch.posteriors.posterior import Posterior
17+
from botorch.sampling.get_sampler import _get_sampler_mvn, GetSampler
18+
from botorch.sampling.normal import NormalMCSampler
1719
from torch import Tensor
1820

1921

2022
class BoundedRiemannPosterior(Posterior):
23+
batch_range = (0, -1)
24+
2125
"""
22-
Notes: Bounded posterior for now, will work on unbounded posteriors.
23-
This is also only over 1 test point, not batches.
26+
A single variate bounded Riemann posterior.
2427
"""
2528

2629
def __init__(self, borders, probabilities):
@@ -31,9 +34,9 @@ def __init__(self, borders, probabilities):
3134
borders, with each bucket having an associated probability.
3235
3336
Args:
34-
borders: A tensor of shape `(n_buckets + 1,)` defining the boundaries of
37+
borders: A tensor of shape `(num_buckets + 1,)` defining the boundaries of
3538
the buckets. Must be monotonically increasing.
36-
probabilities: A tensor of shape `(..., n_buckets,)` defining the
39+
probabilities: A tensor of shape `(b?, q?, num_buckets)` defining the
3740
probability mass in each bucket. Must sum to 1 in the last dim.
3841
"""
3942

@@ -79,18 +82,40 @@ def rsample(
7982
`self._extended_shape(sample_shape=sample_shape)`.
8083
"""
8184
sample_shape = sample_shape if sample_shape is not None else torch.Size([1])
82-
z = torch.rand(sample_shape)
83-
return self.rsample_from_base_samples(sample_shape, z)
85+
base_samples = torch.randn(
86+
sample_shape + self.probabilities.shape[:-1],
87+
device=self.probabilities.device,
88+
)
89+
return self.rsample_from_base_samples(
90+
sample_shape=sample_shape, base_samples=base_samples
91+
)
8492

8593
def rsample_from_base_samples(
86-
self, sample_shape: torch.Size, base_samples: Tensor
94+
self,
95+
sample_shape: torch.Size,
96+
base_samples: Tensor,
8797
) -> Tensor:
98+
"""
99+
base_samples are N(0, I) samples, as this posterior is registered
100+
with a `NormalMCSampler` (via `_get_sampler_mvn`) below. Alternatively it could be registered
101+
with a uniform sampler in which case the transformation to uniform RVs
102+
could be avoided. Shape of base_samples is (nsamp, b?, q).
103+
"""
88104
if base_samples.shape[: len(sample_shape)] != sample_shape:
89-
raise RuntimeError(
105+
raise ValueError(
90106
"`sample_shape` disagrees with shape of `base_samples`. "
91107
f"Got {sample_shape=} and {base_samples.shape=}."
92108
)
93-
return self.icdf(base_samples)
109+
# Convert base samples from N(0, I) to Uniform(0, 1) via the standard normal CDF.
110+
U = torch.distributions.Normal(0, 1).cdf(base_samples)
111+
# Convert U to Riemann samples.
112+
Z = self.icdf(U) # (nsamp, b?, q, 1)
113+
return Z
114+
115+
@property
116+
def base_sample_shape(self) -> torch.Size:
117+
r"""The shape of the base samples required to draw from the posterior."""
118+
return self.probabilities.shape[:-1]
94119

95120
@property
96121
def device(self) -> torch.device:
@@ -137,45 +162,62 @@ def confidence_region(
137162
Use .954 for 2 sigma of a normal distribution.
138163
"""
139164
side_probs = (1.0 - confidence_level) / 2
140-
return self.icdf(side_probs), self.icdf(1.0 - side_probs)
165+
lower = self.icdf(side_probs).squeeze()
166+
upper = self.icdf(1.0 - side_probs).squeeze()
167+
return lower, upper
141168

142-
def icdf(self, value: Union[Tensor, float]) -> Tensor:
169+
def icdf(
170+
self,
171+
value: Union[float, Tensor],
172+
) -> Tensor:
143173
r"""Inverse cdf (with gradients).
144174
Use value to get the index of the bucket that contains the value
145175
and then interpolate between the left and right borders of the bucket
146176
147177
Args:
148178
value: The value at which to evaluate the inverse CDF.
179+
Either a float, or a tensor with shape (b', b?, q), where
180+
probabilities has shape (b?, q, num_buckets).
149181
150182
Returns:
151183
The inverse CDF of the posterior at the given value(s).
152-
The shape of the return is the shape of value, with the batch
153-
shape of the probs (all dims up to the final dim) appended
154-
with a final trailing dimension of 1, for the dim of the dist.
184+
The shape of the return is (b', b?, q, 1), with a trailing
185+
dimension.
155186
"""
187+
if not torch.is_tensor(value):
188+
# Promote to a (b', b?, q) tensor
189+
value = torch.tensor(value, device=self.device, dtype=self.dtype)
190+
value = value.expand(*self.probabilities.shape[:-1]).unsqueeze(0)
191+
value = value.movedim(0, -1) # (b?, q, b')
156192

157-
# final shape is (batch_shape, -1)
158-
value = torch.as_tensor(
159-
value, device=self.borders.device, dtype=self.borders.dtype
160-
)
161-
value_shape = value.shape
162-
# shape of cumprobs is (batch_shape, n_buckets)
163-
value = value.broadcast_to(size=(*self.cumprobs.shape[:-1], *value_shape))
164-
value = value.reshape(*self.cumprobs.shape[:-1], -1)
165-
166-
# get first index where cumprobs > value
167-
index = torch.searchsorted(self.cumprobs, value)
193+
index = torch.searchsorted(self.cumprobs, value) # (b?, q, b')
168194

169-
left_border = self.borders[index]
195+
left_border = self.borders[index] # (b?, q, b')
170196
right_border = self.borders[index + 1]
171197

172198
bucket_width = right_border - left_border
173199
right_cum_probs = torch.gather(self.cumprobs, -1, index)
174200
prob_width = torch.gather(self.probabilities, -1, index)
175201

176202
bucket_proportion_remaining = (right_cum_probs - value) / prob_width
177-
result = left_border + (1 - bucket_proportion_remaining) * bucket_width
178-
179-
# reshape to (value_shape, batch_shape, 1)
180-
result = result.transpose(0, -1)
181-
return result.reshape(*value_shape, *self.cumprobs.shape[:-1], 1)
203+
result = (
204+
right_border - bucket_proportion_remaining * bucket_width
205+
) # (b?, q, b')
206+
207+
# reshape back to (b', b?, q, 1)
208+
result = result.movedim(-1, 0).unsqueeze(-1)
209+
return result
210+
211+
212+
@GetSampler.register(BoundedRiemannPosterior)
213+
def _get_sampler_riemann(
214+
posterior: BoundedRiemannPosterior,
215+
sample_shape: torch.Size,
216+
*,
217+
seed: int | None = None,
218+
) -> NormalMCSampler:
219+
return _get_sampler_mvn(
220+
posterior=posterior,
221+
sample_shape=sample_shape,
222+
seed=seed,
223+
)

0 commit comments

Comments
 (0)