
Commit a675968

saitcakmak authored and facebook-github-bot committed
Put input transforms into train mode before converting models (meta-pytorch#1283)
Summary:
Pull Request resolved: meta-pytorch#1283

Fixes meta-pytorch#1273

During model construction, input transforms should be in `train` mode (so that they are only applied if `transform_on_train` is true). Having the input transforms in eval mode leads to buggy behavior, because `transformed_X` gets transformed when it should not be.

Reviewed By: Balandat

Differential Revision: D37542474

fbshipit-source-id: 7986829612f857995997036f9c48cbeb56d75ceb
1 parent 7ce7c6d commit a675968
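
For context, here is a minimal sketch (not part of the diff) of the mode-dependent behavior the summary describes, assuming BoTorch's default flags for AppendFeatures (transform_on_train=False, transform_on_eval=True): the transform rewrites inputs in eval mode but is a no-op in train mode, which is why the converters below switch transforms to train mode before rebuilding a model.

    import torch
    from botorch.models.transforms.input import AppendFeatures

    # Toy transform that appends candidate feature rows to every input point.
    X = torch.rand(10, 2, dtype=torch.double)
    input_tf = AppendFeatures(feature_set=torch.rand(2, 1, dtype=torch.double))

    input_tf.eval()
    print(input_tf(X).shape)   # eval mode: transform applied, inputs gain appended feature columns/rows

    input_tf.train()
    print(input_tf(X).shape)   # train mode: no-op (transform_on_train=False), shape unchanged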

File tree

botorch/models/converter.py
test/models/test_converter.py

2 files changed: +40 -5 lines changed


botorch/models/converter.py

Lines changed: 6 additions & 0 deletions
@@ -152,6 +152,8 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch

     # construct the batched GP model
     input_transform = getattr(models[0], "input_transform", None)
+    if input_transform is not None:
+        input_transform.train()
     batch_gp = models[0].__class__(input_transform=input_transform, **kwargs)
     adjusted_batch_keys, non_adjusted_batch_keys = _get_adjusted_batch_keys(
         batch_state_dict=batch_gp.state_dict(), input_transform=input_transform
@@ -220,6 +222,8 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
             "Conversion of MixedSingleTaskGP is currently not supported."
         )
     input_transform = getattr(batch_model, "input_transform", None)
+    if input_transform is not None:
+        input_transform.train()
     outcome_transform = getattr(batch_model, "outcome_transform", None)
     batch_sd = batch_model.state_dict()

@@ -324,6 +328,8 @@ def batched_multi_output_to_single_output(
             "Conversion of models with custom likelihoods is currently unsupported."
         )
     input_transform = getattr(batch_mo_model, "input_transform", None)
+    if input_transform is not None:
+        input_transform.train()
     batch_sd = batch_mo_model.state_dict()

     # TODO: add support for outcome transforms.
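
As a usage note, a hypothetical round trip mirroring the new tests below (illustrative only, not code from this commit; variable names are made up): converting models whose input transform was left in eval mode, e.g. after a posterior call, previously hit the bug because the transform fired while the converter rebuilt the model.

    import torch
    from botorch.models import ModelListGP, SingleTaskGP
    from botorch.models.converter import batched_to_model_list, model_list_to_batched
    from botorch.models.transforms.input import AppendFeatures

    train_X = torch.rand(10, 2, dtype=torch.double)
    train_Y1 = train_X.sum(dim=-1, keepdim=True)
    train_Y2 = train_X.prod(dim=-1, keepdim=True)
    input_tf = AppendFeatures(feature_set=torch.rand(2, 1, dtype=torch.double))

    gp1 = SingleTaskGP(train_X, train_Y1, input_transform=input_tf)
    gp2 = SingleTaskGP(train_X, train_Y2, input_transform=input_tf)
    list_gp = ModelListGP(gp1, gp2).eval()  # transforms end up in eval mode here

    # With this fix, the converters put the transform back into train mode internally,
    # so both directions work regardless of the caller's train/eval state.
    batch_gp = model_list_to_batched(list_gp)
    round_trip = batched_to_model_list(batch_gp)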

test/models/test_converter.py

Lines changed: 34 additions & 5 deletions
@@ -19,7 +19,7 @@
     batched_to_model_list,
     model_list_to_batched,
 )
-from botorch.models.transforms.input import Normalize
+from botorch.models.transforms.input import AppendFeatures, Normalize
 from botorch.models.transforms.outcome import Standardize
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.likelihoods import GaussianLikelihood
@@ -80,6 +80,16 @@ def test_batched_to_model_list(self):
                         expected_octf.__getattr__(attr_name),
                     )
                 )
+            # test with AppendFeatures
+            input_tf = AppendFeatures(
+                feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
+            )
+            batch_gp = SingleTaskGP(
+                train_X, train_Y, outcome_transform=octf, input_transform=input_tf
+            ).eval()
+            list_gp = batched_to_model_list(batch_gp)
+            self.assertIsInstance(list_gp, ModelListGP)
+            self.assertIsInstance(list_gp.models[0].input_transform, AppendFeatures)

     def test_model_list_to_batched(self):
         for dtype in (torch.float, torch.double):
@@ -167,6 +177,16 @@ def test_model_list_to_batched(self):
             self.assertTrue(
                 torch.equal(batch_gp.input_transform.bounds, input_tf.bounds)
             )
+            # test with AppendFeatures
+            input_tf3 = AppendFeatures(
+                feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
+            )
+            gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf3)
+            gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf3)
+            list_gp = ModelListGP(gp1_, gp2_).eval()
+            batch_gp = model_list_to_batched(list_gp)
+            self.assertIsInstance(batch_gp, SingleTaskGP)
+            self.assertIsInstance(batch_gp.input_transform, AppendFeatures)
             # test different input transforms
             input_tf2 = Normalize(
                 d=2,
@@ -177,7 +197,7 @@ def test_model_list_to_batched(self):
             gp1_ = SingleTaskGP(train_X, train_Y1, input_transform=input_tf)
             gp2_ = SingleTaskGP(train_X, train_Y2, input_transform=input_tf2)
             list_gp = ModelListGP(gp1_, gp2_)
-            with self.assertRaises(UnsupportedError):
+            with self.assertRaisesRegex(UnsupportedError, "have the same"):
                 model_list_to_batched(list_gp)

             # test batched input transform
@@ -292,17 +312,26 @@ def test_batched_multi_output_to_single_output(self):
             self.assertTrue(
                 torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds)
             )
+            # test with AppendFeatures
+            input_tf = AppendFeatures(
+                feature_set=torch.rand(2, 1, device=self.device, dtype=dtype)
+            )
+            batched_mo_model = SingleTaskGP(
+                train_X, train_Y, input_transform=input_tf
+            ).eval()
+            batch_so_model = batched_multi_output_to_single_output(batched_mo_model)
+            self.assertIsInstance(batch_so_model.input_transform, AppendFeatures)

             # test batched input transform
-            input_tf2 = Normalize(
+            input_tf = Normalize(
                 d=2,
                 bounds=torch.tensor(
                     [[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
                 ),
                 batch_shape=torch.Size([2]),
             )
-            batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf2)
-            batched_so_model = batched_multi_output_to_single_output(batched_mo_model)
+            batched_mo_model = SingleTaskGP(train_X, train_Y, input_transform=input_tf)
+            batch_so_model = batched_multi_output_to_single_output(batched_mo_model)
             self.assertIsInstance(batch_so_model.input_transform, Normalize)
             self.assertTrue(
                 torch.equal(batch_so_model.input_transform.bounds, input_tf.bounds)
