
Commit a5d74d9

Carl Hvarfner authored and facebook-github-bot committed
Condition_on_observations adds data to train_inputs when input_transforms are applied (#2990)

Summary:
Pull Request resolved: #2990

When applying input transforms (even with transform_on_train=True), conditioning on new data does not add it to train_inputs. This is unintuitive: it is not clear why input transforms should prevent the user from adding observations to the training data this way. Since condition_on_observations returns a model (so the user retains the flexibility to replace the old model with the fantasy model or not), the data should always be added to the training inputs.

Fixes #2533. Simple notebook on why this doesn't make sense: N7899660

Reviewed By: saitcakmak

Differential Revision: D80830741

fbshipit-source-id: 422abfa9d9d6212e52e15fa25e445a73d0b3dc42
1 parent e3710e2 commit a5d74d9
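
To illustrate the fixed behavior, here is a minimal sketch (not from the commit; it assumes a stock SingleTaskGP with a Normalize input transform, and shape details may differ across BoTorch versions):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=2))
_ = model.posterior(train_X)  # trigger the transform's train-mode setup

new_X = torch.rand(2, 2, dtype=torch.double)
new_Y = new_X.sum(dim=-1, keepdim=True)
conditioned = model.condition_on_observations(new_X, new_Y)

# With this fix, the returned model's training data includes the new points
# even though an input transform is attached (8 + 2 = 10 rows).
print(conditioned.train_inputs[0].shape[-2])  # expected: 10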

File tree

3 files changed: +116, -6 lines


botorch/models/gpytorch.py
Lines changed: 23 additions & 5 deletions

@@ -242,11 +242,13 @@ def condition_on_observations(
         >>> new_Y = torch.sin(new_X[:, :1]) + torch.cos(new_X[:, 1:])
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
+        # Pass the transformed data to get_fantasy_model below
+        # (unless we have already transformed, as in BatchedMultiOutputGPyTorchModel).
         X = self.transform_inputs(X)
+
         Yvar = noise
         if hasattr(self, "outcome_transform"):
-            # pass the transformed data to get_fantasy_model below
-            # (unless we've already trasnformed if BatchedMultiOutputGPyTorchModel)
+            # And do the same for the outcome transform, if it exists.
            if not isinstance(self, BatchedMultiOutputGPyTorchModel):
                 # `noise` is assumed to already be outcome-transformed.
                 Y, _ = self.outcome_transform(Y=Y, Yvar=Yvar, X=X)
@@ -260,9 +262,25 @@ def condition_on_observations(
         if Yvar is not None:
             kwargs.update({"noise": Yvar.squeeze(-1)})
         # get_fantasy_model will properly copy any existing outcome transforms
-        # (since it deepcopies the original model)
-
-        return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)
+        # (since it deepcopies the original model)
+        fantasy_model = self.get_fantasy_model(inputs=X, targets=Y, **kwargs)
+
+        # If we use an input transform, the fantasized data will not get added to
+        # the training data by default. We need to manually add it.
+        if hasattr(fantasy_model, "input_transform"):
+            # Broadcast tensors to a compatible shape before concatenating.
+            expand_shape = torch.broadcast_shapes(
+                X.shape[:-2], fantasy_model._original_train_inputs.shape[:-2]
+            )
+            X_expanded = X.expand(expand_shape + X.shape[-2:])
+            orig_expanded = fantasy_model._original_train_inputs.expand(
+                expand_shape + fantasy_model._original_train_inputs.shape[-2:]
+            )
+            fantasy_model._original_train_inputs = torch.cat(
+                [orig_expanded, X_expanded],
+                dim=-2,
+            ).detach()
+        return fantasy_model
 
 
 # pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
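
To see what the new broadcasting logic does in isolation, here is a small plain-PyTorch sketch (tensor names are illustrative, not from the diff), mirroring the batched case exercised in the tests below:

import torch

# Existing (already-transformed) training inputs: n=2 points in d=2.
original = torch.rand(2, 2)
# New batched observations to condition on: batch=3, n=2, d=2.
X = torch.rand(3, 2, 2)

# Broadcast the batch dimensions (everything except the last two dims).
expand_shape = torch.broadcast_shapes(X.shape[:-2], original.shape[:-2])

# Expand both tensors to the shared batch shape, then stack along n (dim=-2).
X_expanded = X.expand(expand_shape + X.shape[-2:])                    # (3, 2, 2)
orig_expanded = original.expand(expand_shape + original.shape[-2:])  # (3, 2, 2)
combined = torch.cat([orig_expanded, X_expanded], dim=-2).detach()

print(combined.shape)  # torch.Size([3, 4, 2])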

test/models/test_fully_bayesian.py
Lines changed: 1 addition & 1 deletion

@@ -792,7 +792,7 @@ def test_construct_inputs(self) -> None:
         else:
             self.assertTrue(Yvar.equal(data_dict["train_Yvar"]))
 
-    def test_fbstgp_condition_on_observations(self) -> None:
+    def test_condition_on_observations(self) -> None:
         # The following conditioned data shapes should work (output describes):
         # training data shape after cond(batch shape in output is req. in gpytorch)
         # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
test/models/test_gpytorch.py
Lines changed: 92 additions & 0 deletions

@@ -734,3 +734,95 @@ def test_condition_on_observations_input_transform_consistency(self):
             conditioned_model.train_inputs[0],
             expected_combined_inputs,
         )
+
+    def test_condition_on_observations_train_input_shapes(self):
+        """Comprehensive test for condition_on_observations functionality.
+
+        Tests input transform consistency, train/eval mode stability,
+        different transform settings, and batch shape handling.
+        """
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            # Test 1: Train/eval mode stability with transform_on_train=True
+            train_X = torch.tensor([[0.0], [1.0]], **tkwargs)
+            train_Y = torch.tensor([[1.0], [2.0]], **tkwargs)
+            input_transform = SimpleInputTransform(transform_on_train=True)
+            model = SimpleGPyTorchModel(
+                train_X, train_Y, input_transform=input_transform
+            )
+
+            new_X = torch.tensor([[0.5]], **tkwargs)
+            new_Y = torch.tensor([[1.5]], **tkwargs)
+            _ = model.posterior(train_X)
+            conditioned_model = model.condition_on_observations(new_X, new_Y)
+
+            # Verify conditioned observations persist across train/eval modes
+            conditioned_model.eval()
+            self.assertEqual(conditioned_model.train_targets.shape[0], 3)
+            conditioned_model.train()
+            self.assertEqual(conditioned_model.train_targets.shape[0], 3)
+            self.assertEqual(conditioned_model.train_inputs[0].shape[0], 3)
+
+            # Test 2: Transform behavior with transform_on_train=False
+            model2 = SimpleGPyTorchModel(
+                train_X,
+                train_Y,
+                input_transform=SimpleInputTransform(transform_on_train=False),
+            )
+            _ = model2.posterior(train_X)
+            conditioned_model2 = model2.condition_on_observations(new_X, new_Y)
+            self.assertEqual(conditioned_model2.train_targets.shape[0], 3)
+
+            # Verify model can make predictions after conditioning
+            test_X = torch.tensor([[0.25]], **tkwargs)
+            posterior = conditioned_model2.posterior(test_X)
+            self.assertEqual(posterior.mean.shape, torch.Size([1, 1]))
+
+        # Test 3: Batch shape handling and broadcasting (double precision only)
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        # Same ndim - should update _original_train_inputs
+        train_X = torch.rand(2, 2, **tkwargs)
+        train_Y = torch.rand(2, 1, **tkwargs)
+        model = SimpleGPyTorchModel(
+            train_X, train_Y, input_transform=SimpleInputTransform(True)
+        )
+        _ = model.posterior(train_X)
+
+        original_size = model._original_train_inputs.shape[0]
+        fantasy_model = model.condition_on_observations(
+            torch.rand(1, 2, **tkwargs), torch.rand(1, 1, **tkwargs)
+        )
+        self.assertEqual(
+            fantasy_model._original_train_inputs.shape[0], original_size + 1
+        )
+
+        # Different ndim - tensors are broadcast to a common batch shape first
+        original_size = model._original_train_inputs.shape[0]
+        fantasy_model = model.condition_on_observations(
+            torch.rand(3, 2, 2, **tkwargs), torch.rand(3, 2, 1, **tkwargs)
+        )
+
+        # NOTE: expected behavior is that (2, 2) & (3, 2, 2) are expanded
+        # and then concatenated along dimension -2 --> (3, 4, 2)
+        self.assertEqual(
+            fantasy_model._original_train_inputs.shape, torch.Size([3, 4, 2])
+        )
+
+        # Test 4: Fantasy model behavior
+        model2 = SimpleGPyTorchModel(
+            train_X, train_Y, input_transform=SimpleInputTransform(True)
+        )
+        _ = model2.posterior(train_X)
+        original_size = model2._original_train_inputs.shape[0]
+
+        fantasy_model = model2.condition_on_observations(
+            torch.rand(1, 2, **tkwargs), torch.rand(1, 1, **tkwargs)
+        )
+
+        # Fantasy model gets data, original model does not
+        self.assertEqual(
+            fantasy_model._original_train_inputs.shape[0], original_size + 1
+        )
+        self.assertEqual(model2._original_train_inputs.shape[0], original_size)
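
The final assertion is the key contract: condition_on_observations returns a new (fantasy) model and leaves the original untouched. A hedged sketch of the same check using a stock model (assuming the same SingleTaskGP/Normalize setup as in the earlier example):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize

train_X = torch.rand(4, 2, dtype=torch.double)
train_Y = train_X.norm(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y, input_transform=Normalize(d=2))
_ = model.posterior(train_X)

fantasy = model.condition_on_observations(
    torch.rand(1, 2, dtype=torch.double), torch.rand(1, 1, dtype=torch.double)
)

# The fantasy model holds 4 + 1 = 5 training points; the original keeps 4,
# so users can decide whether to replace the old model with the fantasy one.
print(fantasy.train_inputs[0].shape[-2])  # expected: 5
print(model.train_inputs[0].shape[-2])    # expected: 4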
