@@ -75,120 +75,114 @@ class TestModelListGP(BotorchTestCase):
75
75
def _base_test_ModelListGP (
76
76
self , fixed_noise : bool , dtype , use_octf : bool
77
77
) -> ModelListGP :
78
- # this is to make review easier -- will be removed in the next
79
- # commit in the stack and never landed
80
- unneccessary_condition_for_indentation_remove_me = True
81
- if unneccessary_condition_for_indentation_remove_me :
82
- tkwargs = {"device" : self .device , "dtype" : dtype }
83
- model = _get_model (fixed_noise = fixed_noise , use_octf = use_octf , ** tkwargs )
84
- self .assertIsInstance (model , ModelListGP )
85
- self .assertIsInstance (model .likelihood , LikelihoodList )
86
- for m in model .models :
87
- self .assertIsInstance (m .mean_module , ConstantMean )
88
- self .assertIsInstance (m .covar_module , ScaleKernel )
89
- matern_kernel = m .covar_module .base_kernel
90
- self .assertIsInstance (matern_kernel , MaternKernel )
91
- self .assertIsInstance (matern_kernel .lengthscale_prior , GammaPrior )
92
- if use_octf :
93
- self .assertIsInstance (m .outcome_transform , Standardize )
94
-
95
- # test constructing likelihood wrapper
96
- mll = SumMarginalLogLikelihood (model .likelihood , model )
97
- for mll_ in mll .mlls :
98
- self .assertIsInstance (mll_ , ExactMarginalLogLikelihood )
99
-
100
- # test model fitting (sequential)
101
- with warnings .catch_warnings ():
102
- warnings .filterwarnings ("ignore" , category = OptimizationWarning )
103
- mll = fit_gpytorch_mll (
104
- mll , optimizer_kwargs = {"options" : {"maxiter" : 1 }}, max_attempts = 1
105
- )
106
- with warnings .catch_warnings ():
107
- warnings .filterwarnings ("ignore" , category = OptimizationWarning )
108
- # test model fitting (joint)
109
- mll = fit_gpytorch_mll (
110
- mll ,
111
- optimizer_kwargs = {"options" : {"maxiter" : 1 }},
112
- max_attempts = 1 ,
113
- sequential = False ,
114
- )
115
-
116
- # test subset outputs
117
- subset_model = model .subset_output ([1 ])
118
- self .assertIsInstance (subset_model , ModelListGP )
119
- self .assertEqual (len (subset_model .models ), 1 )
120
- sd_subset = subset_model .models [0 ].state_dict ()
121
- sd = model .models [1 ].state_dict ()
122
- self .assertTrue (set (sd_subset .keys ()) == set (sd .keys ()))
123
- self .assertTrue (all (torch .equal (v , sd [k ]) for k , v in sd_subset .items ()))
124
-
125
- # test posterior
126
- test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
127
- posterior = model .posterior (test_x )
128
- self .assertIsInstance (posterior , GPyTorchPosterior )
129
- self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
78
+ tkwargs = {"device" : self .device , "dtype" : dtype }
79
+ model = _get_model (fixed_noise = fixed_noise , use_octf = use_octf , ** tkwargs )
80
+ self .assertIsInstance (model , ModelListGP )
81
+ self .assertIsInstance (model .likelihood , LikelihoodList )
82
+ for m in model .models :
83
+ self .assertIsInstance (m .mean_module , ConstantMean )
84
+ self .assertIsInstance (m .covar_module , ScaleKernel )
85
+ matern_kernel = m .covar_module .base_kernel
86
+ self .assertIsInstance (matern_kernel , MaternKernel )
87
+ self .assertIsInstance (matern_kernel .lengthscale_prior , GammaPrior )
130
88
if use_octf :
131
- # ensure un-transformation is applied
132
- submodel = model .models [0 ]
133
- p0 = submodel .posterior (test_x )
134
- tmp_tf = submodel .outcome_transform
135
- del submodel .outcome_transform
136
- p0_tf = submodel .posterior (test_x )
137
- submodel .outcome_transform = tmp_tf
138
- expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
139
- self .assertTrue (torch .allclose (p0 .variance , expected_var ))
140
-
141
- # test output_indices
142
- posterior = model .posterior (
143
- test_x , output_indices = [0 ], observation_noise = True
89
+ self .assertIsInstance (m .outcome_transform , Standardize )
90
+
91
+ # test constructing likelihood wrapper
92
+ mll = SumMarginalLogLikelihood (model .likelihood , model )
93
+ for mll_ in mll .mlls :
94
+ self .assertIsInstance (mll_ , ExactMarginalLogLikelihood )
95
+
96
+ # test model fitting (sequential)
97
+ with warnings .catch_warnings ():
98
+ warnings .filterwarnings ("ignore" , category = OptimizationWarning )
99
+ mll = fit_gpytorch_mll (
100
+ mll , optimizer_kwargs = {"options" : {"maxiter" : 1 }}, max_attempts = 1
101
+ )
102
+ with warnings .catch_warnings ():
103
+ warnings .filterwarnings ("ignore" , category = OptimizationWarning )
104
+ # test model fitting (joint)
105
+ mll = fit_gpytorch_mll (
106
+ mll ,
107
+ optimizer_kwargs = {"options" : {"maxiter" : 1 }},
108
+ max_attempts = 1 ,
109
+ sequential = False ,
144
110
)
145
- self .assertIsInstance (posterior , GPyTorchPosterior )
146
- self .assertIsInstance (posterior .distribution , MultivariateNormal )
147
111
148
- # test condition_on_observations
149
- f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
150
- f_y = torch .rand (2 , 2 , ** tkwargs )
151
- if fixed_noise :
152
- noise = 0.1 + 0.1 * torch .rand_like (f_y )
153
- cond_kwargs = {"noise" : noise }
154
- else :
155
- cond_kwargs = {}
156
- cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
157
- self .assertIsInstance (cm , ModelListGP )
158
-
159
- # test condition_on_observations batched
160
- f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
161
- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
162
- cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
163
- self .assertIsInstance (cm , ModelListGP )
164
-
165
- # test condition_on_observations batched (fast fantasies)
166
- f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
167
- f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
168
- cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
169
- self .assertIsInstance (cm , ModelListGP )
112
+ # test subset outputs
113
+ subset_model = model .subset_output ([1 ])
114
+ self .assertIsInstance (subset_model , ModelListGP )
115
+ self .assertEqual (len (subset_model .models ), 1 )
116
+ sd_subset = subset_model .models [0 ].state_dict ()
117
+ sd = model .models [1 ].state_dict ()
118
+ self .assertTrue (set (sd_subset .keys ()) == set (sd .keys ()))
119
+ self .assertTrue (all (torch .equal (v , sd [k ]) for k , v in sd_subset .items ()))
170
120
171
- # test condition_on_observations (incorrect input shape error)
172
- with self .assertRaises (BotorchTensorDimensionError ):
173
- model .condition_on_observations (
174
- f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), ** cond_kwargs
175
- )
121
+ # test posterior
122
+ test_x = torch .tensor ([[0.25 ], [0.75 ]], ** tkwargs )
123
+ posterior = model .posterior (test_x )
124
+ self .assertIsInstance (posterior , GPyTorchPosterior )
125
+ self .assertIsInstance (posterior .distribution , MultitaskMultivariateNormal )
126
+ if use_octf :
127
+ # ensure un-transformation is applied
128
+ submodel = model .models [0 ]
129
+ p0 = submodel .posterior (test_x )
130
+ tmp_tf = submodel .outcome_transform
131
+ del submodel .outcome_transform
132
+ p0_tf = submodel .posterior (test_x )
133
+ submodel .outcome_transform = tmp_tf
134
+ expected_var = tmp_tf .untransform_posterior (p0_tf ).variance
135
+ self .assertTrue (torch .allclose (p0 .variance , expected_var ))
136
+
137
+ # test output_indices
138
+ posterior = model .posterior (test_x , output_indices = [0 ], observation_noise = True )
139
+ self .assertIsInstance (posterior , GPyTorchPosterior )
140
+ self .assertIsInstance (posterior .distribution , MultivariateNormal )
176
141
177
- # test X having wrong size
178
- with self .assertRaises (AssertionError ):
179
- model .condition_on_observations (f_x [:1 ], f_y )
142
+ # test condition_on_observations
143
+ f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
144
+ f_y = torch .rand (2 , 2 , ** tkwargs )
145
+ if fixed_noise :
146
+ noise = 0.1 + 0.1 * torch .rand_like (f_y )
147
+ cond_kwargs = {"noise" : noise }
148
+ else :
149
+ cond_kwargs = {}
150
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
151
+ self .assertIsInstance (cm , ModelListGP )
152
+
153
+ # test condition_on_observations batched
154
+ f_x = [torch .rand (3 , 2 , 1 , ** tkwargs ) for _ in range (2 )]
155
+ f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
156
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
157
+ self .assertIsInstance (cm , ModelListGP )
158
+
159
+ # test condition_on_observations batched (fast fantasies)
160
+ f_x = [torch .rand (2 , 1 , ** tkwargs ) for _ in range (2 )]
161
+ f_y = torch .rand (3 , 2 , 2 , ** tkwargs )
162
+ cm = model .condition_on_observations (f_x , f_y , ** cond_kwargs )
163
+ self .assertIsInstance (cm , ModelListGP )
164
+
165
+ # test condition_on_observations (incorrect input shape error)
166
+ with self .assertRaises (BotorchTensorDimensionError ):
167
+ model .condition_on_observations (
168
+ f_x , torch .rand (3 , 2 , 3 , ** tkwargs ), ** cond_kwargs
169
+ )
180
170
181
- # test posterior transform
182
- X = torch .rand (3 , 1 , ** tkwargs )
183
- weights = torch .tensor ([1 , 2 ], ** tkwargs )
184
- post_tf = ScalarizedPosteriorTransform (weights = weights )
185
- posterior_tf = model .posterior (X , posterior_transform = post_tf )
186
- self .assertTrue (
187
- torch .allclose (
188
- posterior_tf .mean ,
189
- model .posterior (X ).mean @ weights .unsqueeze (- 1 ),
190
- )
171
+ # test X having wrong size
172
+ with self .assertRaises (AssertionError ):
173
+ model .condition_on_observations (f_x [:1 ], f_y )
174
+
175
+ # test posterior transform
176
+ X = torch .rand (3 , 1 , ** tkwargs )
177
+ weights = torch .tensor ([1 , 2 ], ** tkwargs )
178
+ post_tf = ScalarizedPosteriorTransform (weights = weights )
179
+ posterior_tf = model .posterior (X , posterior_transform = post_tf )
180
+ self .assertTrue (
181
+ torch .allclose (
182
+ posterior_tf .mean ,
183
+ model .posterior (X ).mean @ weights .unsqueeze (- 1 ),
191
184
)
185
+ )
192
186
193
187
return model
194
188
0 commit comments