Skip to content

Commit 324b7e2

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add a skip_expand option to AppendFeatures (#1344)
Summary: Pull Request resolved: #1344 Adds a `skip_expand` option to `AppendFeatures` transform, which skips the `expand` operation in the `forward` pass, allowing it to be mixed with an `InputPerturbation` transform without multiplying the size of the output tensor. When used together, each perturbation and each appended feature will appear only once in the output tensor per q-batch. Reviewed By: danielrjiang Differential Revision: D38594609 fbshipit-source-id: 8cf1a8cbfb1be6faf3ba7a524d784f6488f461f0
1 parent 2444d58 commit 324b7e2

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

botorch/models/transforms/input.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,7 @@ class AppendFeatures(InputTransform, Module):
965965
def __init__(
966966
self,
967967
feature_set: Tensor,
968+
skip_expand: bool = False,
968969
transform_on_train: bool = False,
969970
transform_on_eval: bool = True,
970971
transform_on_fantasize: bool = False,
@@ -974,6 +975,10 @@ def __init__(
974975
Args:
975976
feature_set: An `n_f x d_f`-dim tensor denoting the features to be
976977
appended to the inputs.
978+
skip_expand: A boolean indicating whether to expand the input tensor
979+
before appending features. This is intended for use with an
980+
`InputPerturbation`. If `True`, the input tensor will be expected
981+
to be of shape `batch_shape x (q * n_f) x d`.
977982
transform_on_train: A boolean indicating whether to apply the
978983
transforms in train() mode. Default: False.
979984
transform_on_eval: A boolean indicating whether to apply the
@@ -984,6 +989,7 @@ def __init__(
984989
super().__init__()
985990
if feature_set.dim() != 2:
986991
raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!")
992+
self.skip_expand = skip_expand
987993
self.register_buffer("feature_set", feature_set)
988994
self.transform_on_train = transform_on_train
989995
self.transform_on_eval = transform_on_eval
@@ -1003,14 +1009,22 @@ def transform(self, X: Tensor) -> Tensor:
10031009
sample paths.
10041010
10051011
Args:
1006-
X: A `batch_shape x q x d`-dim tensor of inputs.
1012+
X: A `batch_shape x q x d`-dim tensor of inputs. If `self.skip_expand` is
1013+
`True`, then `X` should be of shape `batch_shape x (q * n_f) x d`,
1014+
typically obtained by passing a `batch_shape x q x d` shape input
1015+
through an `InputPerturbation` with `n_f` perturbation values.
10071016
10081017
Returns:
10091018
A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs.
10101019
"""
1011-
expanded_X = X.unsqueeze(dim=-2).expand(
1012-
*X.shape[:-1], self.feature_set.shape[0], -1
1013-
)
1020+
if self.skip_expand:
1021+
expanded_X = X.view(
1022+
*X.shape[:-2], -1, self.feature_set.shape[0], X.shape[-1]
1023+
)
1024+
else:
1025+
expanded_X = X.unsqueeze(dim=-2).expand(
1026+
*X.shape[:-1], self.feature_set.shape[0], -1
1027+
)
10141028
expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1)
10151029
appended_X = torch.cat([expanded_X, expanded_features], dim=-1)
10161030
return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])

test/models/transforms/test_input.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,29 @@ def test_append_features(self):
823823
self.assertEqual(transform.feature_set.device.type, "cpu")
824824
self.assertEqual(transform.feature_set.dtype, torch.half)
825825

826+
def test_w_skip_expand(self):
827+
for dtype in (torch.float, torch.double):
828+
tkwargs = {"device": self.device, "dtype": dtype}
829+
feature_set = torch.tensor([[0.0], [1.0]], **tkwargs)
830+
append_tf = AppendFeatures(feature_set=feature_set, skip_expand=True).eval()
831+
perturbation_set = torch.tensor([[0.0, 0.5], [1.0, 1.5]], **tkwargs)
832+
pert_tf = InputPerturbation(perturbation_set=perturbation_set).eval()
833+
test_X = torch.tensor([[0.0, 0.0], [1.0, 1.0]], **tkwargs)
834+
tf_X = append_tf(pert_tf(test_X))
835+
expected_X = torch.tensor(
836+
[
837+
[0.0, 0.5, 0.0],
838+
[1.0, 1.5, 1.0],
839+
[1.0, 1.5, 0.0],
840+
[2.0, 2.5, 1.0],
841+
],
842+
**tkwargs,
843+
)
844+
self.assertTrue(torch.allclose(tf_X, expected_X))
845+
# Batched evaluation.
846+
tf_X = append_tf(pert_tf(test_X.expand(3, 5, -1, -1)))
847+
self.assertTrue(torch.allclose(tf_X, expected_X.expand(3, 5, -1, -1)))
848+
826849

827850
class TestFilterFeatures(BotorchTestCase):
828851
def test_filter_features(self):

0 commit comments

Comments
 (0)