@@ -72,12 +72,15 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
72
72
73
73
74
74
class TestModelListGP (BotorchTestCase ):
75
- def test_ModelListGP (self ):
76
- for dtype , use_octf in itertools .product (
77
- (torch .float , torch .double ), (False , True )
78
- ):
75
+ def _base_test_ModelListGP (
76
+ self , fixed_noise : bool , dtype , use_octf : bool
77
+ ) -> ModelListGP :
78
+ # this is to make review easier -- will be removed in the next
79
+ # commit in the stack and never landed
80
+ unneccessary_condition_for_indentation_remove_me = True
81
+ if unneccessary_condition_for_indentation_remove_me :
79
82
tkwargs = {"device" : self .device , "dtype" : dtype }
80
- model = _get_model (use_octf = use_octf , ** tkwargs )
83
+ model = _get_model (fixed_noise = fixed_noise , use_octf = use_octf , ** tkwargs )
81
84
self .assertIsInstance (model , ModelListGP )
82
85
self .assertIsInstance (model .likelihood , LikelihoodList )
83
86
for m in model .models :
@@ -135,11 +138,6 @@ def test_ModelListGP(self):
135
138
expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
136
139
self .assertTrue (torch .allclose (p0 .variance , expected_var ))
137
140
138
- # test observation_noise
139
- posterior = model .posterior (test_x , observation_noise = True )
140
- self .assertIsInstance (posterior , GPyTorchPosterior )
141
- self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
142
-
143
141
# test output_indices
144
142
posterior = model .posterior (
145
143
test_x , output_indices = [0 ], observation_noise = True
@@ -150,28 +148,35 @@ def test_ModelListGP(self):
150
148
# test condition_on_observations
151
149
f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
152
150
f_y = torch .rand (2 , 2 , ** tkwargs )
153
- cm = model .condition_on_observations (f_x , f_y )
151
+ if fixed_noise :
152
+ noise = 0.1 + 0.1 * torch .rand_like (f_y )
153
+ cond_kwargs = {"noise" : noise }
154
+ else :
155
+ cond_kwargs = {}
156
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
154
157
self .assertIsInstance (cm , ModelListGP )
155
158
156
159
# test condition_on_observations batched
157
160
f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
158
161
f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
159
- cm = model .condition_on_observations (f_x , f_y )
162
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
160
163
self .assertIsInstance (cm , ModelListGP )
161
164
162
165
# test condition_on_observations batched (fast fantasies)
163
166
f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
164
167
f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
165
- cm = model .condition_on_observations (f_x , f_y )
168
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
166
169
self .assertIsInstance (cm , ModelListGP )
167
170
168
171
# test condition_on_observations (incorrect input shape error)
169
172
with self .assertRaises (BotorchTensorDimensionError ):
170
- model .condition_on_observations (f_x , torch .rand (3 , 2 , 3 , ** tkwargs ))
173
+ model .condition_on_observations (
174
+ f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), ** cond_kwargs
175
+ )
171
176
172
177
# test X having wrong size
173
178
with self .assertRaises (AssertionError ):
174
- cm = model .condition_on_observations (f_x [:1 ], f_y )
179
+ model .condition_on_observations (f_x [:1 ], f_y )
175
180
176
181
# test posterior transform
177
182
X = torch .rand (3 , 1 , ** tkwargs )
@@ -185,82 +190,37 @@ def test_ModelListGP(self):
185
190
)
186
191
)
187
192
188
- def test_ModelListGP_fixed_noise (self ):
193
+ return model
194
+
195
+ def test_ModelListGP (self ) -> None :
189
196
for dtype , use_octf in itertools .product (
190
197
(torch .float , torch .double ), (False , True )
191
198
):
192
- tkwargs = {"device" : self .device , "dtype" : dtype }
193
- model = _get_model (fixed_noise = True , use_octf = use_octf , ** tkwargs )
194
- self .assertIsInstance (model , ModelListGP )
195
- self .assertIsInstance (model .likelihood , LikelihoodList )
196
- for m in model .models :
197
- self .assertIsInstance (m .mean_module , ConstantMean )
198
- self .assertIsInstance (m .covar_module , ScaleKernel )
199
- matern_kernel = m .covar_module .base_kernel
200
- self .assertIsInstance (matern_kernel , MaternKernel )
201
- self .assertIsInstance (matern_kernel .lengthscale_prior , GammaPrior )
202
199
203
- # test model fitting
204
- mll = SumMarginalLogLikelihood (model .likelihood , model )
205
- for mll_ in mll .mlls :
206
- self .assertIsInstance (mll_ , ExactMarginalLogLikelihood )
207
- with warnings .catch_warnings ():
208
- warnings .filterwarnings ("ignore" , category = OptimizationWarning )
209
- mll = fit_gpytorch_mll (
210
- mll , optimizer_kwargs = {"options" : {"maxiter" : 1 }}, max_attempts = 1
211
- )
200
+ model = self ._base_test_ModelListGP (
201
+ fixed_noise = False , dtype = dtype , use_octf = use_octf
202
+ )
203
+ tkwargs = {"device" : self .device , "dtype" : dtype }
212
204
213
- # test posterior
205
+ # test observation_noise
214
206
test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
215
- posterior = model .posterior (test_x )
207
+ posterior = model .posterior (test_x , observation_noise = True )
216
208
self .assertIsInstance (posterior , GPyTorchPosterior )
217
209
self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
218
- if use_octf :
219
- # ensure un-transformation is applied
220
- submodel = model .models [0 ]
221
- p0 = submodel .posterior (test_x )
222
- tmp_tf = submodel .outcome_transform
223
- del submodel .outcome_transform
224
- p0_tf = submodel .posterior (test_x )
225
- submodel .outcome_transform = tmp_tf
226
- expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
227
- self .assertTrue (torch .allclose (p0 .variance , expected_var ))
228
210
229
- # test output_indices
230
- posterior = model .posterior (
231
- test_x , output_indices = [0 ], observation_noise = True
232
- )
233
- self .assertIsInstance (posterior , GPyTorchPosterior )
234
- self .assertIsInstance (posterior .distribution , MultivariateNormal )
211
+ def test_ModelListGP_fixed_noise (self ) -> None :
235
212
236
- # test condition_on_observations
213
+ for dtype , use_octf in itertools .product (
214
+ (torch .float , torch .double ), (False , True )
215
+ ):
216
+ model = self ._base_test_ModelListGP (
217
+ fixed_noise = True , dtype = dtype , use_octf = use_octf
218
+ )
219
+ tkwargs = {"device" : self .device , "dtype" : dtype }
237
220
f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
238
221
f_y = torch .rand (2 , 2 , ** tkwargs )
239
- noise = 0.1 + 0.1 * torch .rand_like (f_y )
240
- cm = model .condition_on_observations (f_x , f_y , noise = noise )
241
- self .assertIsInstance (cm , ModelListGP )
242
-
243
- # test condition_on_observations batched
244
- f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
245
- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
246
- noise = 0.1 + 0.1 * torch .rand_like (f_y )
247
- cm = model .condition_on_observations (f_x , f_y , noise = noise )
248
- self .assertIsInstance (cm , ModelListGP )
249
-
250
- # test condition_on_observations batched (fast fantasies)
251
- f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
252
- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
253
- noise = 0.1 + 0.1 * torch .rand (2 , 2 , ** tkwargs )
254
- cm = model .condition_on_observations (f_x , f_y , noise = noise )
255
- self .assertIsInstance (cm , ModelListGP )
256
222
257
- # test condition_on_observations (incorrect input shape error)
258
- with self .assertRaises (BotorchTensorDimensionError ):
259
- model .condition_on_observations (
260
- f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), noise = noise
261
- )
262
223
# test condition_on_observations (incorrect noise shape error)
263
- f_y = torch .rand (2 , 2 , ** tkwargs )
264
224
with self .assertRaises (BotorchTensorDimensionError ):
265
225
model .condition_on_observations (
266
226
f_x , f_y , noise = torch .rand (2 , 3 , ** tkwargs )
0 commit comments