18
18
import torch
19
19
from botorch .acquisition .objective import PosteriorTransform
20
20
from botorch .exceptions .errors import UnsupportedError
21
-
22
21
from botorch .logging import logger
23
22
from botorch .models .model import Model
24
23
from botorch .models .transforms .input import InputTransform
24
+ from botorch .utils .transforms import match_batch_shape
25
25
from botorch_community .models .utils .prior_fitted_network import (
26
26
download_model ,
27
27
ModelPaths ,
28
28
)
29
29
from botorch_community .posteriors .riemann import BoundedRiemannPosterior
30
+ from gpytorch .likelihoods .gaussian_likelihood import FixedNoiseGaussianLikelihood
30
31
from pfns .train import MainConfig # @manual=//pytorch/PFNs:PFNs
31
32
from torch import Tensor
32
33
from torch .nn import Module
@@ -58,7 +59,7 @@ def __init__(
58
59
59
60
Args:
60
61
train_X: A `n x d` tensor of training features.
61
- train_Y: A `n x m ` tensor of training observations.
62
+ train_Y: A `n x 1 ` tensor of training observations.
62
63
model: A pre-trained PFN model with the following
63
64
forward(train_X, train_Y, X) -> logit predictions of shape
64
65
`n x b x c` where c is the number of discrete buckets
@@ -95,40 +96,35 @@ def __init__(
95
96
if train_Yvar is not None :
96
97
logger .debug ("train_Yvar provided but ignored for PFNModel." )
97
98
98
- if not ( 1 <= train_Y .dim () <= 3 ) :
99
- raise UnsupportedError ("train_Y must be 1- to 3 -dimensional." )
99
+ if train_Y .dim () != 2 :
100
+ raise UnsupportedError ("train_Y must be 2 -dimensional." )
100
101
101
- if not ( 2 <= train_X .dim () <= 3 ) :
102
- raise UnsupportedError ("train_X must be 2- to 3- dimensional." )
102
+ if train_X .dim () != 2 :
103
+ raise UnsupportedError ("train_X must be 2-dimensional." )
103
104
104
- if train_Y .dim () == train_X .dim ():
105
- if train_Y .shape [- 1 ] > 1 :
106
- raise UnsupportedError ("Only 1 target allowed for PFNModel." )
107
- train_Y = train_Y .squeeze (- 1 )
105
+ if train_Y .shape [- 1 ] > 1 :
106
+ raise UnsupportedError ("Only 1 target allowed for PFNModel." )
108
107
109
- if (len (train_X .shape ) != len (train_Y .shape ) + 1 ) or (
110
- train_Y .shape != train_X .shape [:- 1 ]
111
- ):
108
+ if train_X .shape [0 ] != train_Y .shape [0 ]:
112
109
raise UnsupportedError (
113
- "train_X and train_Y must have the same shape except "
114
- "for the last dimension."
110
+ "train_X and train_Y must have the same number of rows."
115
111
)
116
112
117
- if len (train_X .shape ) == 2 :
118
- # adding batch dimension
119
- train_X = train_X .unsqueeze (0 )
120
- train_Y = train_Y .unsqueeze (0 )
121
-
122
113
with torch .no_grad ():
123
114
self .transformed_X = self .transform_inputs (
124
115
X = train_X , input_transform = input_transform
125
116
)
126
117
127
- self .train_X = train_X # shape: `b x n x d`
128
- self .train_Y = train_Y # shape: `b x n`
129
- self .pfn = model .to (train_X .device )
118
+ self .train_X = train_X # shape: (n, d)
119
+ self .train_Y = train_Y # shape: (n, 1)
120
+ # Downstream botorch tooling expects a likelihood to be specified,
121
+ # so here we use a FixedNoiseGaussianLikelihood that is unused.
122
+ if train_Yvar is None :
123
+ train_Yvar = torch .zeros_like (train_Y )
124
+ self .likelihood = FixedNoiseGaussianLikelihood (noise = train_Yvar )
125
+ self .pfn = model .to (device = train_X .device )
130
126
self .batch_first = batch_first
131
- self .constant_model_kwargs = constant_model_kwargs
127
+ self .constant_model_kwargs = constant_model_kwargs or {}
132
128
if input_transform is not None :
133
129
self .input_transform = input_transform
134
130
@@ -146,23 +142,19 @@ def posterior(
146
142
any `model.forward` or `model.likelihood` calls.
147
143
148
144
Args:
149
- X: A `b'? x b? x q x d`-dim Tensor, where `d` is the dimension of the
150
- feature space, `q` is the number of points considered jointly,
151
- and `b` is the batch dimension.
152
- We only allow `q=1` for PFNModel, so q can also be omitted, i.e.
153
- `b x d`-dim Tensor.
154
- **Currently not supported for PFNModel**.
145
+ X: A `b? x q? x d`-dim Tensor, where `d` is the dimension of the
146
+ feature space.
155
147
output_indices: **Currently not supported for PFNModel.**
156
148
observation_noise: **Currently not supported for PFNModel**.
157
149
posterior_transform: **Currently not supported for PFNModel**.
158
150
159
151
Returns:
160
- A `BoundedRiemannPosterior` object , representing a batch of `b` joint
161
- distributions over `q` points and `m` outputs each .
152
+ A `BoundedRiemannPosterior`, representing a batch of `b? x q?`
153
+ distributions.
162
154
"""
163
155
self .pfn .eval ()
164
156
if output_indices is not None :
165
- raise RuntimeError (
157
+ raise UnsupportedError (
166
158
"output_indices is not None. PFNModel should not "
167
159
"be a multi-output model."
168
160
)
@@ -173,60 +165,54 @@ def posterior(
173
165
if posterior_transform is not None :
174
166
raise UnsupportedError ("posterior_transform is not supported for PFNModel." )
175
167
176
- if not (1 <= len (X .shape ) <= 4 ):
177
- raise UnsupportedError ("X must be 1- to 4-dimensional." )
178
-
179
- # X has shape b'? x b? x q? x d
180
-
181
- orig_X_shape = X .shape
182
- q_in_orig_X_shape = len (X .shape ) > 2
183
-
184
- if len (X .shape ) == 1 :
185
- X = X .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (0 ) # shape `b'=1 x b=1 x q=1 x d`
186
- elif len (X .shape ) == 2 :
187
- X = X .unsqueeze (1 ).unsqueeze (1 ) # shape `b' x b=1 x q=1 x d`
188
- elif len (X .shape ) == 3 :
189
- if self .train_X .shape [0 ] == 1 :
190
- X = X .unsqueeze (1 ) # shape `b' x b=1 x q x d`
191
- else :
192
- X = X .unsqueeze (0 ) # shape `b'=1 x b x q x d`
193
-
194
- # X has shape `b' x b x q x d`
195
-
196
- if X .shape [2 ] != 1 :
197
- raise UnsupportedError ("Only q=1 is supported for PFNModel." )
198
-
199
- # X has shape `b' x b x q=1 x d`
200
- X = self .transform_inputs (X )
201
- train_X = self .transformed_X # shape `b x n x d`
202
- train_Y = self .train_Y # shape `b x n`
203
- folded_X = X .transpose (0 , 2 ).squeeze (0 ) # shape `b x b' x d
204
-
205
- constant_model_kwargs = self .constant_model_kwargs or {}
206
-
207
- if self .batch_first :
208
- logits = self .pfn (
209
- train_X .float (),
210
- train_X .float (),
211
- folded_X .float (),
212
- ** constant_model_kwargs ,
213
- ).transpose (0 , 1 )
214
- else :
215
- logits = self .pfn (
216
- train_X .float ().transpose (0 , 1 ),
217
- train_Y .float ().transpose (0 , 1 ),
218
- folded_X .float ().transpose (0 , 1 ),
219
- ** constant_model_kwargs ,
220
- )
221
-
222
- # logits shape `b' x b x logits_dim`
168
+ orig_X_shape = X .shape # X has shape b? x q? x d
169
+ X = self .prepare_X (X ) # shape (b, q, d)
170
+ train_X = match_batch_shape (self .transformed_X , X ) # shape (b, n, d)
171
+ train_Y = match_batch_shape (self .train_Y , X ) # shape (b, n, 1)
223
172
224
- logits = logits .view (
173
+ probabilities = self .pfn_predict (
174
+ X = X , train_X = train_X , train_Y = train_Y
175
+ ) # (b, q, num_buckets)
176
+ probabilities = probabilities .view (
225
177
* orig_X_shape [:- 1 ], - 1
226
- ) # orig shape w/o q but logits_dim at end: `b'? x b? x q? x logits_dim`
227
- if q_in_orig_X_shape :
228
- logits = logits .squeeze (- 2 ) # shape `b'? x b? x logits_dim`
178
+ ) # (b?, q?, num_buckets)
229
179
230
- probabilities = logits .softmax (dim = - 1 )
180
+ # Get posterior with the right dtype
181
+ borders = self .pfn .criterion .borders .to (X .dtype )
182
+ return BoundedRiemannPosterior (
183
+ borders = borders ,
184
+ probabilities = probabilities ,
185
+ )
231
186
232
- return BoundedRiemannPosterior (self .pfn .criterion .borders , probabilities )
187
def prepare_X(self, X: Tensor) -> Tensor:
    """Bring `X` to three dimensions and apply the input transform.

    Args:
        X: A tensor with at most three dimensions.

    Returns:
        The output of `self.transform_inputs` applied to `X` after leading
        singleton dimensions have been prepended until it is 3-dimensional.

    Raises:
        UnsupportedError: If `X` has more than three dimensions.
    """
    num_dims = X.dim()
    if num_dims > 3:
        raise UnsupportedError(f"X must be at most 3-d, got { X .shape } .")

    # Prepend one singleton dimension per missing leading dimension.
    for _ in range(3 - num_dims):
        X = X.unsqueeze(0)

    return self.transform_inputs(X)  # shape (b, q, d)
195
+
196
def pfn_predict(self, X: Tensor, train_X: Tensor, train_Y: Tensor) -> Tensor:
    """Evaluate the PFN on `X` given training data; return bucket probabilities.

    Args:
        X: Test inputs of shape (b, q, d).
        train_X: Training inputs of shape (b, n, d).
        train_Y: Training targets of shape (b, n, 1).

    Returns:
        The softmax over the last dimension of the PFN logits, computed in
        the dtype of `X`; shape (b, q, num_buckets).
    """
    out_dtype = X.dtype
    model_inputs = (train_X, train_Y, X)

    if not self.batch_first:
        # The PFN expects the batch in the second position: swap dims 0 and 1.
        model_inputs = tuple(t.transpose(0, 1) for t in model_inputs)

    # The PFN is always fed float32 inputs, whatever dtype the caller uses.
    logits = self.pfn(
        *(t.float() for t in model_inputs),
        **self.constant_model_kwargs,
    )

    if not self.batch_first:
        logits = logits.transpose(0, 1)  # back to shape (b, q, num_buckets)

    # Cast to the caller's dtype before normalising.
    return logits.to(out_dtype).softmax(dim=-1)
0 commit comments