19
19
batched_to_model_list ,
20
20
model_list_to_batched ,
21
21
)
22
- from botorch .models .transforms .input import Normalize
22
+ from botorch .models .transforms .input import AppendFeatures , Normalize
23
23
from botorch .models .transforms .outcome import Standardize
24
24
from botorch .utils .testing import BotorchTestCase
25
25
from gpytorch .likelihoods import GaussianLikelihood
@@ -80,6 +80,16 @@ def test_batched_to_model_list(self):
80
80
expected_octf .__getattr__ (attr_name ),
81
81
)
82
82
)
83
+ # test with AppendFeatures
84
+ input_tf = AppendFeatures (
85
+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
86
+ )
87
+ batch_gp = SingleTaskGP (
88
+ train_X , train_Y , outcome_transform = octf , input_transform = input_tf
89
+ ).eval ()
90
+ list_gp = batched_to_model_list (batch_gp )
91
+ self .assertIsInstance (list_gp , ModelListGP )
92
+ self .assertIsInstance (list_gp .models [0 ].input_transform , AppendFeatures )
83
93
84
94
def test_model_list_to_batched (self ):
85
95
for dtype in (torch .float , torch .double ):
@@ -167,6 +177,16 @@ def test_model_list_to_batched(self):
167
177
self .assertTrue (
168
178
torch .equal (batch_gp .input_transform .bounds , input_tf .bounds )
169
179
)
180
+ # test with AppendFeatures
181
+ input_tf3 = AppendFeatures (
182
+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
183
+ )
184
+ gp1_ = SingleTaskGP (train_X , train_Y1 , input_transform = input_tf3 )
185
+ gp2_ = SingleTaskGP (train_X , train_Y2 , input_transform = input_tf3 )
186
+ list_gp = ModelListGP (gp1_ , gp2_ ).eval ()
187
+ batch_gp = model_list_to_batched (list_gp )
188
+ self .assertIsInstance (batch_gp , SingleTaskGP )
189
+ self .assertIsInstance (batch_gp .input_transform , AppendFeatures )
170
190
# test different input transforms
171
191
input_tf2 = Normalize (
172
192
d = 2 ,
@@ -177,7 +197,7 @@ def test_model_list_to_batched(self):
177
197
gp1_ = SingleTaskGP (train_X , train_Y1 , input_transform = input_tf )
178
198
gp2_ = SingleTaskGP (train_X , train_Y2 , input_transform = input_tf2 )
179
199
list_gp = ModelListGP (gp1_ , gp2_ )
180
- with self .assertRaises (UnsupportedError ):
200
+ with self .assertRaisesRegex (UnsupportedError , "have the same" ):
181
201
model_list_to_batched (list_gp )
182
202
183
203
# test batched input transform
@@ -292,17 +312,26 @@ def test_batched_multi_output_to_single_output(self):
292
312
self .assertTrue (
293
313
torch .equal (batch_so_model .input_transform .bounds , input_tf .bounds )
294
314
)
315
+ # test with AppendFeatures
316
+ input_tf = AppendFeatures (
317
+ feature_set = torch .rand (2 , 1 , device = self .device , dtype = dtype )
318
+ )
319
+ batched_mo_model = SingleTaskGP (
320
+ train_X , train_Y , input_transform = input_tf
321
+ ).eval ()
322
+ batch_so_model = batched_multi_output_to_single_output (batched_mo_model )
323
+ self .assertIsInstance (batch_so_model .input_transform , AppendFeatures )
295
324
296
325
# test batched input transform
297
- input_tf2 = Normalize (
326
+ input_tf = Normalize (
298
327
d = 2 ,
299
328
bounds = torch .tensor (
300
329
[[- 1.0 , - 1.0 ], [1.0 , 1.0 ]], device = self .device , dtype = dtype
301
330
),
302
331
batch_shape = torch .Size ([2 ]),
303
332
)
304
- batched_mo_model = SingleTaskGP (train_X , train_Y , input_transform = input_tf2 )
305
- batched_so_model = batched_multi_output_to_single_output (batched_mo_model )
333
+ batched_mo_model = SingleTaskGP (train_X , train_Y , input_transform = input_tf )
334
+ batch_so_model = batched_multi_output_to_single_output (batched_mo_model )
306
335
self .assertIsInstance (batch_so_model .input_transform , Normalize )
307
336
self .assertTrue (
308
337
torch .equal (batch_so_model .input_transform .bounds , input_tf .bounds )
0 commit comments