Skip to content

Commit 2ecdbc1

Browse files
jduerholtfacebook-github-bot
authored andcommitted
AppendFeaturesFromCallable InputTransform (#1354)
Summary: ## Motivation This PR adds the functionality of transfer learning and feature engineering within a model via a new input transform called `AppendFeaturesFromCallable` as discussed in #1307. This implementation does not alter the original `AppendFeatures` transform but adds a new one to keep things clean and separate. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes. Pull Request resolved: #1354 Test Plan: Unit tests. Reviewed By: Balandat Differential Revision: D38907448 Pulled By: saitcakmak fbshipit-source-id: 05ed63540ce486e89cdf0e5cef23fa50ab49bc4a
1 parent b663c72 commit 2ecdbc1

File tree

2 files changed

+300
-20
lines changed

2 files changed

+300
-20
lines changed

botorch/models/transforms/input.py

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from abc import ABC, abstractmethod
2020
from collections import OrderedDict
21-
from typing import Callable, List, Optional, Union
21+
from typing import Any, Callable, Dict, List, Optional, Union
2222

2323
import torch
2424
from botorch.exceptions.errors import BotorchTensorDimensionError
@@ -948,11 +948,13 @@ def _untransform(self, X: Tensor) -> Tensor:
948948

949949

950950
class AppendFeatures(InputTransform, Module):
951-
r"""A transform that appends the input with a given set of features.
951+
r"""A transform that appends the input with a given set of features either
952+
provided beforehand or generated on the fly via a callable.
952953
953-
As an example, this can be used with `RiskMeasureMCObjective` to optimize risk
954-
measures as described in [Cakmak2020risk]_. A tutorial notebook implementing the
955-
rhoKG acqusition function introduced in [Cakmak2020risk]_ can be found at
954+
As an example, the predefined set of features can be used with
955+
`RiskMeasureMCObjective` to optimize risk measures as described in
956+
[Cakmak2020risk]_. A tutorial notebook implementing the rhoKG acqusition
957+
function introduced in [Cakmak2020risk]_ can be found at
956958
https://botorch.org/tutorials/risk_averse_bo_with_environmental_variables.
957959
958960
The steps for using this to obtain samples of a risk measure are as follows:
@@ -973,6 +975,11 @@ class AppendFeatures(InputTransform, Module):
973975
since the `feature_set` does not fully represent the distribution of the
974976
environmental variable.
975977
978+
Possible examples for using a callable include statistical models that are built on
979+
PyTorch, built-in mathematical operations such as torch.sum, or custom scripted
980+
functions. By this, this input transform allows for advanced feature engineering
981+
and transfer learning models within the optimization loop.
982+
976983
Example:
977984
>>> # We consider 1D `x` and 1D `w`, with `W` having a
978985
>>> # uniform distribution over [0, 1]
@@ -994,21 +1001,34 @@ class AppendFeatures(InputTransform, Module):
9941001

9951002
def __init__(
9961003
self,
997-
feature_set: Tensor,
1004+
feature_set: Optional[Tensor] = None,
1005+
f: Optional[Callable[[Tensor], Tensor]] = None,
1006+
indices: Optional[List[int]] = None,
1007+
fkwargs: Optional[Dict[str, Any]] = None,
9981008
skip_expand: bool = False,
9991009
transform_on_train: bool = False,
10001010
transform_on_eval: bool = True,
10011011
transform_on_fantasize: bool = False,
10021012
) -> None:
1003-
r"""Append `feature_set` to each input.
1013+
r"""Append `feature_set` to each input or generate a set of features to
1014+
append on the fly via a callable.
10041015
10051016
Args:
10061017
feature_set: An `n_f x d_f`-dim tensor denoting the features to be
1007-
appended to the inputs.
1018+
appended to the inputs. Default: None.
1019+
f: A callable mapping a `batch_shape x q x d`-dim input tensor `X`
1020+
to a `batch_shape x q x n_f x d_f`-dimensional output tensor.
1021+
Default: None.
1022+
indices: List of indices denoting the indices of the features to be
1023+
passed into f. Per default all features are passed to `f`.
1024+
Default: None.
1025+
fkwargs: Dictionary of keyword arguments passed to the callable `f`.
1026+
Default: None.
10081027
skip_expand: A boolean indicating whether to expand the input tensor
10091028
before appending features. This is intended for use with an
10101029
`InputPerturbation`. If `True`, the input tensor will be expected
1011-
to be of shape `batch_shape x (q * n_f) x d`.
1030+
to be of shape `batch_shape x (q * n_f) x d`. Not implemented
1031+
in combination with a callable.
10121032
transform_on_train: A boolean indicating whether to apply the
10131033
transforms in train() mode. Default: False.
10141034
transform_on_eval: A boolean indicating whether to apply the
@@ -1017,16 +1037,44 @@ def __init__(
10171037
transform when called from within a `fantasize` call. Default: False.
10181038
"""
10191039
super().__init__()
1020-
if feature_set.dim() != 2:
1021-
raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!")
1040+
if (feature_set is None) and (f is None):
1041+
raise ValueError(
1042+
"Either a `feature_set` or a callable `f` has to be provided."
1043+
)
1044+
if (feature_set is not None) and (f is not None):
1045+
raise ValueError(
1046+
"Only one can be used: either `feature_set` or callable `f`."
1047+
)
1048+
if feature_set is not None:
1049+
if feature_set.dim() != 2:
1050+
raise ValueError("`feature_set` must be an `n_f x d_f`-dim tensor!")
1051+
self.register_buffer("feature_set", feature_set)
1052+
self._f = None
1053+
if f is not None:
1054+
if skip_expand:
1055+
raise ValueError(
1056+
"`skip_expand` option is not supported in case of using a callable"
1057+
)
1058+
if (indices is not None) and (len(indices) == 0):
1059+
raise ValueError("`indices` list is empty!")
1060+
if indices is not None:
1061+
indices = torch.tensor(indices, dtype=torch.long)
1062+
if len(indices.unique()) != len(indices):
1063+
raise ValueError("Elements of `indices` tensor must be unique!")
1064+
self.indices = indices
1065+
else:
1066+
self.indices = slice(None)
1067+
self._f = f
1068+
self.fkwargs = fkwargs or {}
1069+
10221070
self.skip_expand = skip_expand
1023-
self.register_buffer("feature_set", feature_set)
10241071
self.transform_on_train = transform_on_train
10251072
self.transform_on_eval = transform_on_eval
10261073
self.transform_on_fantasize = transform_on_fantasize
10271074

10281075
def transform(self, X: Tensor) -> Tensor:
1029-
r"""Transform the inputs by appending `feature_set` to each input.
1076+
r"""Transform the inputs by appending `feature_set` to each input or
1077+
by generating a set of features to be appended on the fly via a callable.
10301078
10311079
For each `1 x d`-dim element in the input tensor, this will produce
10321080
an `n_f x (d + d_f)`-dim tensor with `feature_set` appended as the last `d_f`
@@ -1047,15 +1095,20 @@ def transform(self, X: Tensor) -> Tensor:
10471095
Returns:
10481096
A `batch_shape x (q * n_f) x (d + d_f)`-dim tensor of appended inputs.
10491097
"""
1098+
if self._f is not None:
1099+
expanded_features = self._f(X[..., self.indices], **self.fkwargs)
1100+
n_f = expanded_features.shape[-2]
1101+
else:
1102+
n_f = self.feature_set.shape[-2]
1103+
10501104
if self.skip_expand:
1051-
expanded_X = X.view(
1052-
*X.shape[:-2], -1, self.feature_set.shape[0], X.shape[-1]
1053-
)
1105+
expanded_X = X.view(*X.shape[:-2], -1, n_f, X.shape[-1])
10541106
else:
1055-
expanded_X = X.unsqueeze(dim=-2).expand(
1056-
*X.shape[:-1], self.feature_set.shape[0], -1
1057-
)
1058-
expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1)
1107+
expanded_X = X.unsqueeze(dim=-2).expand(*X.shape[:-1], n_f, -1)
1108+
1109+
if self._f is None:
1110+
expanded_features = self.feature_set.expand(*expanded_X.shape[:-1], -1)
1111+
10591112
appended_X = torch.cat([expanded_X, expanded_features], dim=-1)
10601113
return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])
10611114

test/models/transforms/test_input.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,233 @@ def test_w_skip_expand(self):
860860
tf_X = append_tf(pert_tf(test_X.expand(3, 5, -1, -1)))
861861
self.assertTrue(torch.allclose(tf_X, expected_X.expand(3, 5, -1, -1)))
862862

863+
def test_w_f(self):
864+
def f1(x: Tensor, n_f: int = 1) -> Tensor:
865+
result = torch.sum(x, dim=-1, keepdim=True).unsqueeze(-2)
866+
return result.expand(*result.shape[:-2], n_f, -1)
867+
868+
def f2(x: Tensor, n_f: int = 1) -> Tensor:
869+
result = x[..., -2:].unsqueeze(-2)
870+
return result.expand(*result.shape[:-2], n_f, -1)
871+
872+
for dtype in [torch.float, torch.double]:
873+
tkwargs = {"device": self.device, "dtype": dtype}
874+
875+
# test init
876+
with self.assertRaises(ValueError):
877+
transform = AppendFeatures(f=f1, indices=[0, 0])
878+
with self.assertRaises(ValueError):
879+
transform = AppendFeatures(f=f1, indices=[])
880+
with self.assertRaises(ValueError):
881+
transform = AppendFeatures(f=f1, skip_expand=True)
882+
with self.assertRaises(ValueError):
883+
transform = AppendFeatures(feature_set=None, f=None)
884+
with self.assertRaises(ValueError):
885+
transform = AppendFeatures(
886+
feature_set=torch.linspace(0, 1, 6)
887+
.view(3, 2)
888+
.to(device=self.device, dtype=dtype),
889+
f=f1,
890+
)
891+
892+
# test functionality with n_f = 1
893+
X = torch.rand(1, 3, **tkwargs)
894+
transform = AppendFeatures(
895+
f=f1,
896+
transform_on_eval=True,
897+
transform_on_train=True,
898+
transform_on_fantasize=True,
899+
)
900+
X_transformed = transform(X)
901+
self.assertEqual(X_transformed.shape, torch.Size((1, 4)))
902+
self.assertTrue(torch.allclose(X.sum(dim=-1), X_transformed[..., -1]))
903+
904+
X = torch.rand(10, 3, **tkwargs)
905+
transform = AppendFeatures(
906+
f=f1,
907+
transform_on_eval=True,
908+
transform_on_train=True,
909+
transform_on_fantasize=True,
910+
)
911+
X_transformed = transform(X)
912+
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))
913+
self.assertTrue(torch.allclose(X.sum(dim=-1), X_transformed[..., -1]))
914+
915+
transform = AppendFeatures(
916+
f=f1,
917+
indices=[0, 1],
918+
transform_on_eval=True,
919+
transform_on_train=True,
920+
transform_on_fantasize=True,
921+
)
922+
X_transformed = transform(X)
923+
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))
924+
self.assertTrue(
925+
torch.allclose(X[..., [0, 1]].sum(dim=-1), X_transformed[..., -1])
926+
)
927+
928+
transform = AppendFeatures(
929+
f=f2,
930+
transform_on_eval=True,
931+
transform_on_train=True,
932+
transform_on_fantasize=True,
933+
)
934+
X_transformed = transform(X)
935+
self.assertEqual(X_transformed.shape, torch.Size((10, 5)))
936+
937+
X = torch.rand(1, 10, 3).to(**tkwargs)
938+
transform = AppendFeatures(
939+
f=f1,
940+
transform_on_eval=True,
941+
transform_on_train=True,
942+
transform_on_fantasize=True,
943+
)
944+
X_transformed = transform(X)
945+
self.assertEqual(X_transformed.shape, torch.Size((1, 10, 4)))
946+
947+
X = torch.rand(1, 1, 3).to(**tkwargs)
948+
transform = AppendFeatures(
949+
f=f1,
950+
transform_on_eval=True,
951+
transform_on_train=True,
952+
transform_on_fantasize=True,
953+
)
954+
X_transformed = transform(X)
955+
self.assertEqual(X_transformed.shape, torch.Size((1, 1, 4)))
956+
957+
X = torch.rand(2, 10, 3).to(**tkwargs)
958+
transform = AppendFeatures(
959+
f=f1,
960+
transform_on_eval=True,
961+
transform_on_train=True,
962+
transform_on_fantasize=True,
963+
)
964+
X_transformed = transform(X)
965+
self.assertEqual(X_transformed.shape, torch.Size((2, 10, 4)))
966+
967+
transform = AppendFeatures(
968+
f=f2,
969+
transform_on_eval=True,
970+
transform_on_train=True,
971+
transform_on_fantasize=True,
972+
)
973+
X_transformed = transform(X)
974+
self.assertEqual(X_transformed.shape, torch.Size((2, 10, 5)))
975+
self.assertTrue(torch.allclose(X[..., -2:], X_transformed[..., -2:]))
976+
977+
# test functionality with n_f > 1
978+
X = torch.rand(10, 3, **tkwargs)
979+
transform = AppendFeatures(
980+
f=f1,
981+
fkwargs={"n_f": 2},
982+
transform_on_eval=True,
983+
transform_on_train=True,
984+
transform_on_fantasize=True,
985+
)
986+
X_transformed = transform(X)
987+
self.assertEqual(X_transformed.shape, torch.Size((20, 4)))
988+
989+
X = torch.rand(2, 10, 3, **tkwargs)
990+
transform = AppendFeatures(
991+
f=f1,
992+
fkwargs={"n_f": 2},
993+
transform_on_eval=True,
994+
transform_on_train=True,
995+
transform_on_fantasize=True,
996+
)
997+
X_transformed = transform(X)
998+
self.assertEqual(X_transformed.shape, torch.Size((2, 20, 4)))
999+
1000+
X = torch.rand(1, 10, 3, **tkwargs)
1001+
transform = AppendFeatures(
1002+
f=f1,
1003+
fkwargs={"n_f": 2},
1004+
transform_on_eval=True,
1005+
transform_on_train=True,
1006+
transform_on_fantasize=True,
1007+
)
1008+
X_transformed = transform(X)
1009+
self.assertEqual(X_transformed.shape, torch.Size((1, 20, 4)))
1010+
1011+
X = torch.rand(1, 3, **tkwargs)
1012+
transform = AppendFeatures(
1013+
f=f1,
1014+
fkwargs={"n_f": 2},
1015+
transform_on_eval=True,
1016+
transform_on_train=True,
1017+
transform_on_fantasize=True,
1018+
)
1019+
X_transformed = transform(X)
1020+
self.assertEqual(X_transformed.shape, torch.Size((2, 4)))
1021+
1022+
X = torch.rand(10, 3, **tkwargs)
1023+
transform = AppendFeatures(
1024+
f=f2,
1025+
fkwargs={"n_f": 2},
1026+
transform_on_eval=True,
1027+
transform_on_train=True,
1028+
transform_on_fantasize=True,
1029+
)
1030+
X_transformed = transform(X)
1031+
self.assertEqual(X_transformed.shape, torch.Size((20, 5)))
1032+
1033+
X = torch.rand(2, 10, 3, **tkwargs)
1034+
transform = AppendFeatures(
1035+
f=f2,
1036+
fkwargs={"n_f": 2},
1037+
transform_on_eval=True,
1038+
transform_on_train=True,
1039+
transform_on_fantasize=True,
1040+
)
1041+
X_transformed = transform(X)
1042+
self.assertEqual(X_transformed.shape, torch.Size((2, 20, 5)))
1043+
1044+
X = torch.rand(1, 10, 3, **tkwargs)
1045+
transform = AppendFeatures(
1046+
f=f2,
1047+
fkwargs={"n_f": 2},
1048+
transform_on_eval=True,
1049+
transform_on_train=True,
1050+
transform_on_fantasize=True,
1051+
)
1052+
X_transformed = transform(X)
1053+
self.assertEqual(X_transformed.shape, torch.Size((1, 20, 5)))
1054+
1055+
X = torch.rand(1, 3, **tkwargs)
1056+
transform = AppendFeatures(
1057+
f=f2,
1058+
fkwargs={"n_f": 2},
1059+
transform_on_eval=True,
1060+
transform_on_train=True,
1061+
transform_on_fantasize=True,
1062+
)
1063+
X_transformed = transform(X)
1064+
self.assertEqual(X_transformed.shape, torch.Size((2, 5)))
1065+
1066+
# test no transform on train
1067+
X = torch.rand(10, 3).to(**tkwargs)
1068+
transform = AppendFeatures(
1069+
f=f1, transform_on_train=False, transform_on_eval=True
1070+
)
1071+
transform.train()
1072+
X_transformed = transform(X)
1073+
self.assertTrue(torch.equal(X, X_transformed))
1074+
transform.eval()
1075+
X_transformed = transform(X)
1076+
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))
1077+
1078+
# test not transform on eval
1079+
X = torch.rand(10, 3).to(**tkwargs)
1080+
transform = AppendFeatures(
1081+
f=f1, transform_on_eval=False, transform_on_train=True
1082+
)
1083+
transform.eval()
1084+
X_transformed = transform(X)
1085+
self.assertTrue(torch.equal(X, X_transformed))
1086+
transform.train()
1087+
X_transformed = transform(X)
1088+
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))
1089+
8631090

8641091
class TestFilterFeatures(BotorchTestCase):
8651092
def test_filter_features(self):

0 commit comments

Comments
 (0)