@@ -75,120 +75,114 @@ class TestModelListGP(BotorchTestCase):
     def _base_test_ModelListGP(
         self, fixed_noise: bool, dtype, use_octf: bool
     ) -> ModelListGP:
-        # this is to make review easier -- will be removed in the next
-        # commit in the stack and never landed
-        unneccessary_condition_for_indentation_remove_me = True
-        if unneccessary_condition_for_indentation_remove_me:
-            tkwargs = {"device": self.device, "dtype": dtype}
-            model = _get_model(fixed_noise=fixed_noise, use_octf=use_octf, **tkwargs)
-            self.assertIsInstance(model, ModelListGP)
-            self.assertIsInstance(model.likelihood, LikelihoodList)
-            for m in model.models:
-                self.assertIsInstance(m.mean_module, ConstantMean)
-                self.assertIsInstance(m.covar_module, ScaleKernel)
-                matern_kernel = m.covar_module.base_kernel
-                self.assertIsInstance(matern_kernel, MaternKernel)
-                self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
-                if use_octf:
-                    self.assertIsInstance(m.outcome_transform, Standardize)
-
-            # test constructing likelihood wrapper
-            mll = SumMarginalLogLikelihood(model.likelihood, model)
-            for mll_ in mll.mlls:
-                self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
-
-            # test model fitting (sequential)
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=OptimizationWarning)
-                mll = fit_gpytorch_mll(
-                    mll, optimizer_kwargs={"options": {"maxiter": 1}}, max_attempts=1
-                )
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", category=OptimizationWarning)
-                # test model fitting (joint)
-                mll = fit_gpytorch_mll(
-                    mll,
-                    optimizer_kwargs={"options": {"maxiter": 1}},
-                    max_attempts=1,
-                    sequential=False,
-                )
-
-            # test subset outputs
-            subset_model = model.subset_output([1])
-            self.assertIsInstance(subset_model, ModelListGP)
-            self.assertEqual(len(subset_model.models), 1)
-            sd_subset = subset_model.models[0].state_dict()
-            sd = model.models[1].state_dict()
-            self.assertTrue(set(sd_subset.keys()) == set(sd.keys()))
-            self.assertTrue(all(torch.equal(v, sd[k]) for k, v in sd_subset.items()))
-
-            # test posterior
-            test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
-            posterior = model.posterior(test_x)
-            self.assertIsInstance(posterior, GPyTorchPosterior)
-            self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
+        tkwargs = {"device": self.device, "dtype": dtype}
+        model = _get_model(fixed_noise=fixed_noise, use_octf=use_octf, **tkwargs)
+        self.assertIsInstance(model, ModelListGP)
+        self.assertIsInstance(model.likelihood, LikelihoodList)
+        for m in model.models:
+            self.assertIsInstance(m.mean_module, ConstantMean)
+            self.assertIsInstance(m.covar_module, ScaleKernel)
+            matern_kernel = m.covar_module.base_kernel
+            self.assertIsInstance(matern_kernel, MaternKernel)
+            self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
             if use_octf:
-                # ensure un-transformation is applied
-                submodel = model.models[0]
-                p0 = submodel.posterior(test_x)
-                tmp_tf = submodel.outcome_transform
-                del submodel.outcome_transform
-                p0_tf = submodel.posterior(test_x)
-                submodel.outcome_transform = tmp_tf
-                expected_var = tmp_tf.untransform_posterior(p0_tf).variance
-                self.assertTrue(torch.allclose(p0.variance, expected_var))
-
-            # test output_indices
-            posterior = model.posterior(
-                test_x, output_indices=[0], observation_noise=True
+                self.assertIsInstance(m.outcome_transform, Standardize)
+
+        # test constructing likelihood wrapper
+        mll = SumMarginalLogLikelihood(model.likelihood, model)
+        for mll_ in mll.mlls:
+            self.assertIsInstance(mll_, ExactMarginalLogLikelihood)
+
+        # test model fitting (sequential)
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=OptimizationWarning)
+            mll = fit_gpytorch_mll(
+                mll, optimizer_kwargs={"options": {"maxiter": 1}}, max_attempts=1
+            )
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore", category=OptimizationWarning)
+            # test model fitting (joint)
+            mll = fit_gpytorch_mll(
+                mll,
+                optimizer_kwargs={"options": {"maxiter": 1}},
+                max_attempts=1,
+                sequential=False,
             )
-            self.assertIsInstance(posterior, GPyTorchPosterior)
-            self.assertIsInstance(posterior.distribution, MultivariateNormal)

-            # test condition_on_observations
-            f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
-            f_y = torch.rand(2, 2, **tkwargs)
-            if fixed_noise:
-                noise = 0.1 + 0.1 * torch.rand_like(f_y)
-                cond_kwargs = {"noise": noise}
-            else:
-                cond_kwargs = {}
-            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
-            self.assertIsInstance(cm, ModelListGP)
-
-            # test condition_on_observations batched
-            f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
-            f_y = torch.rand(3, 2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
-            self.assertIsInstance(cm, ModelListGP)
-
-            # test condition_on_observations batched (fast fantasies)
-            f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
-            f_y = torch.rand(3, 2, 2, **tkwargs)
-            cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
-            self.assertIsInstance(cm, ModelListGP)
+        # test subset outputs
+        subset_model = model.subset_output([1])
+        self.assertIsInstance(subset_model, ModelListGP)
+        self.assertEqual(len(subset_model.models), 1)
+        sd_subset = subset_model.models[0].state_dict()
+        sd = model.models[1].state_dict()
+        self.assertTrue(set(sd_subset.keys()) == set(sd.keys()))
+        self.assertTrue(all(torch.equal(v, sd[k]) for k, v in sd_subset.items()))

-            # test condition_on_observations (incorrect input shape error)
-            with self.assertRaises(BotorchTensorDimensionError):
-                model.condition_on_observations(
-                    f_x, torch.rand(3, 2, 3, **tkwargs), **cond_kwargs
-                )
+        # test posterior
+        test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
+        posterior = model.posterior(test_x)
+        self.assertIsInstance(posterior, GPyTorchPosterior)
+        self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
+        if use_octf:
+            # ensure un-transformation is applied
+            submodel = model.models[0]
+            p0 = submodel.posterior(test_x)
+            tmp_tf = submodel.outcome_transform
+            del submodel.outcome_transform
+            p0_tf = submodel.posterior(test_x)
+            submodel.outcome_transform = tmp_tf
+            expected_var = tmp_tf.untransform_posterior(p0_tf).variance
+            self.assertTrue(torch.allclose(p0.variance, expected_var))
+
+        # test output_indices
+        posterior = model.posterior(test_x, output_indices=[0], observation_noise=True)
+        self.assertIsInstance(posterior, GPyTorchPosterior)
+        self.assertIsInstance(posterior.distribution, MultivariateNormal)

-            # test X having wrong size
-            with self.assertRaises(AssertionError):
-                model.condition_on_observations(f_x[:1], f_y)
+        # test condition_on_observations
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
+        f_y = torch.rand(2, 2, **tkwargs)
+        if fixed_noise:
+            noise = 0.1 + 0.1 * torch.rand_like(f_y)
+            cond_kwargs = {"noise": noise}
+        else:
+            cond_kwargs = {}
+        cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
+        self.assertIsInstance(cm, ModelListGP)
+
+        # test condition_on_observations batched
+        f_x = [torch.rand(3, 2, 1, **tkwargs) for _ in range(2)]
+        f_y = torch.rand(3, 2, 2, **tkwargs)
+        cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
+        self.assertIsInstance(cm, ModelListGP)
+
+        # test condition_on_observations batched (fast fantasies)
+        f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
+        f_y = torch.rand(3, 2, 2, **tkwargs)
+        cm = model.condition_on_observations(f_x, f_y, **cond_kwargs)
+        self.assertIsInstance(cm, ModelListGP)
+
+        # test condition_on_observations (incorrect input shape error)
+        with self.assertRaises(BotorchTensorDimensionError):
+            model.condition_on_observations(
+                f_x, torch.rand(3, 2, 3, **tkwargs), **cond_kwargs
+            )

-            # test posterior transform
-            X = torch.rand(3, 1, **tkwargs)
-            weights = torch.tensor([1, 2], **tkwargs)
-            post_tf = ScalarizedPosteriorTransform(weights=weights)
-            posterior_tf = model.posterior(X, posterior_transform=post_tf)
-            self.assertTrue(
-                torch.allclose(
-                    posterior_tf.mean,
-                    model.posterior(X).mean @ weights.unsqueeze(-1),
-                )
+        # test X having wrong size
+        with self.assertRaises(AssertionError):
+            model.condition_on_observations(f_x[:1], f_y)
+
+        # test posterior transform
+        X = torch.rand(3, 1, **tkwargs)
+        weights = torch.tensor([1, 2], **tkwargs)
+        post_tf = ScalarizedPosteriorTransform(weights=weights)
+        posterior_tf = model.posterior(X, posterior_transform=post_tf)
+        self.assertTrue(
+            torch.allclose(
+                posterior_tf.mean,
+                model.posterior(X).mean @ weights.unsqueeze(-1),
             )
+        )

         return model
