@@ -72,12 +72,15 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
7272
7373
7474class TestModelListGP (BotorchTestCase ):
75- def test_ModelListGP (self ):
76- for dtype , use_octf in itertools .product (
77- (torch .float , torch .double ), (False , True )
78- ):
75+ def _base_test_ModelListGP (
76+ self , fixed_noise : bool , dtype , use_octf : bool
77+ ) -> ModelListGP :
78+ # this is to make review easier -- will be removed in the next
79+ # commit in the stack and never landed
80+ unneccessary_condition_for_indentation_remove_me = True
81+ if unneccessary_condition_for_indentation_remove_me :
7982 tkwargs = {"device" : self .device , "dtype" : dtype }
80- model = _get_model (use_octf = use_octf , ** tkwargs )
83+ model = _get_model (fixed_noise = fixed_noise , use_octf = use_octf , ** tkwargs )
8184 self .assertIsInstance (model , ModelListGP )
8285 self .assertIsInstance (model .likelihood , LikelihoodList )
8386 for m in model .models :
@@ -135,11 +138,6 @@ def test_ModelListGP(self):
135138 expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
136139 self .assertTrue (torch .allclose (p0 .variance , expected_var ))
137140
138- # test observation_noise
139- posterior = model .posterior (test_x , observation_noise = True )
140- self .assertIsInstance (posterior , GPyTorchPosterior )
141- self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
142-
143141 # test output_indices
144142 posterior = model .posterior (
145143 test_x , output_indices = [0 ], observation_noise = True
@@ -150,28 +148,35 @@ def test_ModelListGP(self):
150148 # test condition_on_observations
151149 f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
152150 f_y = torch .rand (2 , 2 , ** tkwargs )
153- cm = model .condition_on_observations (f_x , f_y )
151+ if fixed_noise :
152+ noise = 0.1 + 0.1 * torch .rand_like (f_y )
153+ cond_kwargs = {"noise" : noise }
154+ else :
155+ cond_kwargs = {}
156+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
154157 self .assertIsInstance (cm , ModelListGP )
155158
156159 # test condition_on_observations batched
157160 f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
158161 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
159- cm = model .condition_on_observations (f_x , f_y )
162+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
160163 self .assertIsInstance (cm , ModelListGP )
161164
162165 # test condition_on_observations batched (fast fantasies)
163166 f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
164167 f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
165- cm = model .condition_on_observations (f_x , f_y )
168+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
166169 self .assertIsInstance (cm , ModelListGP )
167170
168171 # test condition_on_observations (incorrect input shape error)
169172 with self .assertRaises (BotorchTensorDimensionError ):
170- model .condition_on_observations (f_x , torch .rand (3 , 2 , 3 , ** tkwargs ))
173+ model .condition_on_observations (
174+ f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), ** cond_kwargs
175+ )
171176
172177 # test X having wrong size
173178 with self .assertRaises (AssertionError ):
174- cm = model .condition_on_observations (f_x [:1 ], f_y )
179+ model .condition_on_observations (f_x [:1 ], f_y )
175180
176181 # test posterior transform
177182 X = torch .rand (3 , 1 , ** tkwargs )
@@ -185,82 +190,37 @@ def test_ModelListGP(self):
185190 )
186191 )
187192
188- def test_ModelListGP_fixed_noise (self ):
193+ return model
194+
195+ def test_ModelListGP (self ) -> None :
189196 for dtype , use_octf in itertools .product (
190197 (torch .float , torch .double ), (False , True )
191198 ):
192- tkwargs = {"device" : self .device , "dtype" : dtype }
193- model = _get_model (fixed_noise = True , use_octf = use_octf , ** tkwargs )
194- self .assertIsInstance (model , ModelListGP )
195- self .assertIsInstance (model .likelihood , LikelihoodList )
196- for m in model .models :
197- self .assertIsInstance (m .mean_module , ConstantMean )
198- self .assertIsInstance (m .covar_module , ScaleKernel )
199- matern_kernel = m .covar_module .base_kernel
200- self .assertIsInstance (matern_kernel , MaternKernel )
201- self .assertIsInstance (matern_kernel .lengthscale_prior , GammaPrior )
202199
203- # test model fitting
204- mll = SumMarginalLogLikelihood (model .likelihood , model )
205- for mll_ in mll .mlls :
206- self .assertIsInstance (mll_ , ExactMarginalLogLikelihood )
207- with warnings .catch_warnings ():
208- warnings .filterwarnings ("ignore" , category = OptimizationWarning )
209- mll = fit_gpytorch_mll (
210- mll , optimizer_kwargs = {"options" : {"maxiter" : 1 }}, max_attempts = 1
211- )
200+ model = self ._base_test_ModelListGP (
201+ fixed_noise = False , dtype = dtype , use_octf = use_octf
202+ )
203+ tkwargs = {"device" : self .device , "dtype" : dtype }
212204
213- # test posterior
205+ # test observation_noise
214206 test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
215- posterior = model .posterior (test_x )
207+ posterior = model .posterior (test_x , observation_noise = True )
216208 self .assertIsInstance (posterior , GPyTorchPosterior )
217209 self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
218- if use_octf :
219- # ensure un-transformation is applied
220- submodel = model .models [0 ]
221- p0 = submodel .posterior (test_x )
222- tmp_tf = submodel .outcome_transform
223- del submodel .outcome_transform
224- p0_tf = submodel .posterior (test_x )
225- submodel .outcome_transform = tmp_tf
226- expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
227- self .assertTrue (torch .allclose (p0 .variance , expected_var ))
228210
229- # test output_indices
230- posterior = model .posterior (
231- test_x , output_indices = [0 ], observation_noise = True
232- )
233- self .assertIsInstance (posterior , GPyTorchPosterior )
234- self .assertIsInstance (posterior .distribution , MultivariateNormal )
211+ def test_ModelListGP_fixed_noise (self ) -> None :
235212
236- # test condition_on_observations
213+ for dtype , use_octf in itertools .product (
214+ (torch .float , torch .double ), (False , True )
215+ ):
216+ model = self ._base_test_ModelListGP (
217+ fixed_noise = True , dtype = dtype , use_octf = use_octf
218+ )
219+ tkwargs = {"device" : self .device , "dtype" : dtype }
237220 f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
238221 f_y = torch .rand (2 , 2 , ** tkwargs )
239- noise = 0.1 + 0.1 * torch .rand_like (f_y )
240- cm = model .condition_on_observations (f_x , f_y , noise = noise )
241- self .assertIsInstance (cm , ModelListGP )
242-
243- # test condition_on_observations batched
244- f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
245- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
246- noise = 0.1 + 0.1 * torch .rand_like (f_y )
247- cm = model .condition_on_observations (f_x , f_y , noise = noise )
248- self .assertIsInstance (cm , ModelListGP )
249-
250- # test condition_on_observations batched (fast fantasies)
251- f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
252- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
253- noise = 0.1 + 0.1 * torch .rand (2 , 2 , ** tkwargs )
254- cm = model .condition_on_observations (f_x , f_y , noise = noise )
255- self .assertIsInstance (cm , ModelListGP )
256222
257- # test condition_on_observations (incorrect input shape error)
258- with self .assertRaises (BotorchTensorDimensionError ):
259- model .condition_on_observations (
260- f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), noise = noise
261- )
262223 # test condition_on_observations (incorrect noise shape error)
263- f_y = torch .rand (2 , 2 , ** tkwargs )
264224 with self .assertRaises (BotorchTensorDimensionError ):
265225 model .condition_on_observations (
266226 f_x , f_y , noise = torch .rand (2 , 3 , ** tkwargs )
0 commit comments