@@ -65,13 +65,11 @@ def _get_fixed_noise_model_single_output(**tkwargs):
6565
6666class TestMultiTaskGP (BotorchTestCase ):
6767 def test_MultiTaskGP (self ):
68- for double in (False , True ):
69- tkwargs = {
70- "device" : self .device ,
71- "dtype" : torch .double if double else torch .float ,
72- }
68+ for dtype in (torch .float , torch .double ):
69+ tkwargs = {"device" : self .device , "dtype" : dtype }
7370 model = _get_model (** tkwargs )
7471 self .assertIsInstance (model , MultiTaskGP )
72+ self .assertEqual (model .num_outputs , 2 )
7573 self .assertIsInstance (model .likelihood , GaussianLikelihood )
7674 self .assertIsInstance (model .mean_module , ConstantMean )
7775 self .assertIsInstance (model .covar_module , ScaleKernel )
@@ -140,13 +138,11 @@ def test_MultiTaskGP(self):
140138 model .posterior (test_x )
141139
142140 def test_MultiTaskGP_single_output (self ):
143- for double in (False , True ):
144- tkwargs = {
145- "device" : self .device ,
146- "dtype" : torch .double if double else torch .float ,
147- }
141+ for dtype in (torch .float , torch .double ):
142+ tkwargs = {"device" : self .device , "dtype" : dtype }
148143 model = _get_model_single_output (** tkwargs )
149144 self .assertIsInstance (model , MultiTaskGP )
145+ self .assertEqual (model .num_outputs , 1 )
150146 self .assertIsInstance (model .likelihood , GaussianLikelihood )
151147 self .assertIsInstance (model .mean_module , ConstantMean )
152148 self .assertIsInstance (model .covar_module , ScaleKernel )
@@ -180,13 +176,11 @@ def test_MultiTaskGP_single_output(self):
180176
181177class TestFixedNoiseMultiTaskGP (BotorchTestCase ):
182178 def test_FixedNoiseMultiTaskGP (self ):
183- for double in (False , True ):
184- tkwargs = {
185- "device" : self .device ,
186- "dtype" : torch .double if double else torch .float ,
187- }
179+ for dtype in (torch .float , torch .double ):
180+ tkwargs = {"device" : self .device , "dtype" : dtype }
188181 model = _get_fixed_noise_model (** tkwargs )
189182 self .assertIsInstance (model , FixedNoiseMultiTaskGP )
183+ self .assertEqual (model .num_outputs , 2 )
190184 self .assertIsInstance (model .likelihood , FixedNoiseGaussianLikelihood )
191185 self .assertIsInstance (model .mean_module , ConstantMean )
192186 self .assertIsInstance (model .covar_module , ScaleKernel )
@@ -253,13 +247,11 @@ def test_FixedNoiseMultiTaskGP(self):
253247 FixedNoiseMultiTaskGP (train_X , train_Y , train_Yvar , 0 , output_tasks = [2 ])
254248
255249 def test_FixedNoiseMultiTaskGP_single_output (self ):
256- for double in (False , True ):
257- tkwargs = {
258- "device" : self .device ,
259- "dtype" : torch .double if double else torch .float ,
260- }
250+ for dtype in (torch .float , torch .double ):
251+ tkwargs = {"device" : self .device , "dtype" : dtype }
261252 model = _get_fixed_noise_model_single_output (** tkwargs )
262253 self .assertIsInstance (model , FixedNoiseMultiTaskGP )
254+ self .assertEqual (model .num_outputs , 1 )
263255 self .assertIsInstance (model .likelihood , FixedNoiseGaussianLikelihood )
264256 self .assertIsInstance (model .mean_module , ConstantMean )
265257 self .assertIsInstance (model .covar_module , ScaleKernel )
0 commit comments