
Commit 3ec2ab0

jduerholt authored and meta-codesync[bot] committed
Fix bug in condition_on_observation (#3034)
Summary:

## Motivation

Fix the bug outlined in #3033. Tests are not yet adapted; first, I want to get your opinion on the fix ;)

### Have you read the [Contributing Guidelines on pull requests](https://github.com/meta-pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #3034

Test Plan: Unit tests, not yet adapted.

Reviewed By: hvarfner

Differential Revision: D83766789

Pulled By: saitcakmak

fbshipit-source-id: 787d885d617e398fae22c2447ba76156b75f5d46
1 parent 26e96d4 commit 3ec2ab0

File tree

2 files changed: +48 -3 lines

botorch/models/gpytorch.py

Lines changed: 3 additions & 2 deletions
@@ -244,6 +244,7 @@ def condition_on_observations(
         """
         # pass the transformed data to get_fantasy_model below
         # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
+        X_original = X.clone()
         X = self.transform_inputs(X)
 
         Yvar = noise
@@ -270,9 +271,9 @@
         if hasattr(fantasy_model, "input_transform"):
             # Broadcast tensors to compatible shape before concatenating
             expand_shape = torch.broadcast_shapes(
-                X.shape[:-2], fantasy_model._original_train_inputs.shape[:-2]
+                X_original.shape[:-2], fantasy_model._original_train_inputs.shape[:-2]
             )
-            X_expanded = X.expand(expand_shape + X.shape[-2:])
+            X_expanded = X_original.expand(expand_shape + X_original.shape[-2:])
             orig_expanded = fantasy_model._original_train_inputs.expand(
                 expand_shape + fantasy_model._original_train_inputs.shape[-2:]
             )
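Why keeping the untransformed X matters here: fantasy_model._original_train_inputs stores data in the untransformed space, so when the input transform changes the feature dimension (for example, one-hot encoding a categorical column), the transformed X has a different number of columns and cannot be broadcast and concatenated with it consistently. The snippet below is a minimal, hypothetical illustration of that dimension mismatch; the encode_last_column helper is made up for the sketch and is not part of BoTorch:

import torch
from torch.nn.functional import one_hot

def encode_last_column(X: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
    # Replace the last (categorical) column with its one-hot encoding,
    # widening the feature dimension from d to d - 1 + num_classes.
    cats = one_hot(X[..., -1].long(), num_classes=num_classes).to(X)
    return torch.cat([X[..., :-1], cats], dim=-1)

original_train_inputs = torch.rand(10, 3)   # untransformed training data, 3 features
X_new = torch.rand(2, 3)                    # new observations, also untransformed
X_transformed = encode_last_column(X_new)   # shape (2, 5) after encoding

# Concatenating along the sample dimension requires matching feature dimensions:
torch.cat([original_train_inputs, X_new], dim=-2).shape        # torch.Size([12, 3])
# torch.cat([original_train_inputs, X_transformed], dim=-2)    # would fail: 3 vs. 5

With the patch, the untransformed X_original is the tensor expanded against _original_train_inputs, so the stored original training data stays in the untransformed space while get_fantasy_model still receives the transformed X.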

test/models/test_gpytorch.py

Lines changed: 45 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 import itertools
 import warnings
+from functools import partial
 
 import torch
 from botorch.acquisition.objective import ScalarizedPosteriorTransform
@@ -24,7 +25,11 @@
 from botorch.models.model import FantasizeMixin
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms import Standardize
-from botorch.models.transforms.input import ChainedInputTransform, InputTransform
+from botorch.models.transforms.input import (
+    ChainedInputTransform,
+    InputTransform,
+    NumericToCategoricalEncoding,
+)
 from botorch.models.utils import fantasize
 from botorch.posteriors.gpytorch import GPyTorchPosterior
 from botorch.sampling.normal import SobolQMCNormalSampler
@@ -39,6 +44,8 @@
 from gpytorch.settings import trace_mode
 from torch import Tensor
 
+from torch.nn.functional import one_hot
+
 
 class SimpleInputTransform(InputTransform, torch.nn.Module):
     def __init__(self, transform_on_train: bool) -> None:
@@ -691,6 +698,43 @@ def test_condition_on_observations_model_list(self):
                 X=torch.rand(2, 1, **tkwargs), Y=torch.rand(2, 2, **tkwargs)
             )
 
+    def test_condition_on_observations_input_transform_shape_manipulation(self):
+        for dtype in (torch.float, torch.double):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            # Create data
+            X = torch.rand(12, 2, **tkwargs) * 2
+            Y = 1 - (X - 0.5).norm(dim=-1, keepdim=True)
+            Y += 0.1 * torch.rand_like(Y)
+            # Add a categorical feature
+            new_col = torch.randint(0, 3, (X.shape[0], 1), **tkwargs)
+            X = torch.cat([X, new_col], dim=1)
+
+            train_X = X[:10]
+            train_Y = Y[:10]
+
+            condition_X = X[10:]
+            condition_Y = Y[10:]
+
+            # setup transform and model
+            input_transform = NumericToCategoricalEncoding(
+                dim=3,
+                categorical_features={2: 3},
+                encoders={2: partial(one_hot, num_classes=3)},
+            )
+
+            model = SimpleGPyTorchModel(
+                train_X, train_Y, input_transform=input_transform
+            )
+            model.eval()
+            _ = model.posterior(train_X)
+
+            conditioned_model = model.condition_on_observations(
+                condition_X, condition_Y
+            )
+            self.assertAllClose(conditioned_model._original_train_inputs, X)
+            self.assertAllClose(conditioned_model.train_inputs[0], input_transform(X))
+
     def test_condition_on_observations_input_transform_consistency(self):
         """Test that input transforms are applied consistently in
         condition_on_observations.
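For reference, the encoder passed to NumericToCategoricalEncoding in the new test is just torch.nn.functional.one_hot with a fixed number of classes. A quick standalone check of what it produces (values chosen arbitrarily for illustration):

from functools import partial

import torch
from torch.nn.functional import one_hot

encoder = partial(one_hot, num_classes=3)
print(encoder(torch.tensor([0, 2, 1])))
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])

This widening of the categorical column is why the test asserts that train_inputs[0] holds the transformed data while _original_train_inputs retains the raw three-column X.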
