Skip to content

Commit 5f2028d

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add is_one_to_many attribute to input transforms (#1396)
Summary: Pull Request resolved: #1396 Adds a simple attribute to replace `isinstance(tf, (AppendFeatures, InputPerturbation))` checks. For chained transforms, this is set to true if at least one of the child transforms is one-to-many. Reviewed By: Balandat Differential Revision: D39455329 fbshipit-source-id: 3381b6299c3f83a4b68fbf0d579481ea2ffd0417
1 parent 4ee9e03 commit 5f2028d

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

botorch/models/transforms/input.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class InputTransform(ABC):
4141
between `gpytorch.module.Module` and `torch.nn.Module` in `Warp`.
4242
4343
Properties:
44+
is_one_to_many: A boolean denoting whether the transform produces
45+
multiple values for each input.
4446
transform_on_train: A boolean indicating whether to apply the
4547
transform in train() mode.
4648
transform_on_eval: A boolean indicating whether to apply the
@@ -51,6 +53,7 @@ class InputTransform(ABC):
5153
:meta private:
5254
"""
5355

56+
is_one_to_many: bool = False
5457
transform_on_eval: bool
5558
transform_on_train: bool
5659
transform_on_fantasize: bool
@@ -177,6 +180,7 @@ def __init__(self, **transforms: InputTransform) -> None:
177180
self.transform_on_eval = False
178181
self.transform_on_fantasize = False
179182
for tf in transforms.values():
183+
self.is_one_to_many |= tf.is_one_to_many
180184
self.transform_on_train |= tf.transform_on_train
181185
self.transform_on_eval |= tf.transform_on_eval
182186
self.transform_on_fantasize |= tf.transform_on_fantasize
@@ -999,6 +1003,8 @@ class AppendFeatures(InputTransform, Module):
9991003
>>> risk_measure_samples = risk_measure(posterior_samples)
10001004
"""
10011005

1006+
is_one_to_many: bool = True
1007+
10021008
def __init__(
10031009
self,
10041010
feature_set: Optional[Tensor] = None,
@@ -1196,6 +1202,8 @@ class InputPerturbation(InputTransform, Module):
11961202
https://botorch.org/tutorials/risk_averse_bo_with_input_perturbations.
11971203
"""
11981204

1205+
is_one_to_many: bool = True
1206+
11991207
def __init__(
12001208
self,
12011209
perturbation_set: Union[Tensor, Callable[[Tensor], Tensor]],

test/models/transforms/test_input.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,10 @@ def test_abstract_base_input_transform(self):
126126
self.assertTrue(torch.equal(ipt5(X), X))
127127

128128
def test_normalize(self):
129-
130129
for dtype in (torch.float, torch.double):
131-
132130
# basic init, learned bounds
133131
nlz = Normalize(d=2)
132+
self.assertFalse(nlz.is_one_to_many)
134133
self.assertTrue(nlz.learn_bounds)
135134
self.assertTrue(nlz.training)
136135
self.assertEqual(nlz._d, 2)
@@ -316,7 +315,6 @@ def test_normalize(self):
316315

317316
def test_standardize(self):
318317
for dtype in (torch.float, torch.double):
319-
320318
# basic init
321319
stdz = InputStandardize(d=2)
322320
self.assertTrue(stdz.training)
@@ -455,7 +453,6 @@ def test_standardize(self):
455453
self.assertFalse(stdz7.equals(stdz8))
456454

457455
def test_chained_input_transform(self):
458-
459456
ds = (1, 2)
460457
batch_shapes = (torch.Size(), torch.Size([2]))
461458
dtypes = (torch.float, torch.double)
@@ -473,6 +470,7 @@ def test_chained_input_transform(self):
473470
self.assertEqual(sorted(tf.keys()), ["stz_fixed", "stz_learned"])
474471
self.assertEqual(tf["stz_fixed"], tf1)
475472
self.assertEqual(tf["stz_learned"], tf2)
473+
self.assertFalse(tf.is_one_to_many)
476474

477475
X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype)
478476
X_tf = tf(X)
@@ -515,6 +513,11 @@ def test_chained_input_transform(self):
515513
tf = ChainedInputTransform(stz_fixed=tf1, stz_learned=tf2)
516514
self.assertTrue(torch.equal(tf.preprocess_transform(X), tf1.transform(X)))
517515

516+
# test one-to-many
517+
tf2 = InputPerturbation(perturbation_set=bounds)
518+
tf = ChainedInputTransform(stz=tf1, pert=tf2)
519+
self.assertTrue(tf.is_one_to_many)
520+
518521
def test_round_transform(self):
519522
for dtype in (torch.float, torch.double):
520523
# basic init
@@ -811,6 +814,7 @@ def test_append_features(self):
811814
torch.linspace(0, 1, 6).view(3, 2).to(device=self.device, dtype=dtype)
812815
)
813816
transform = AppendFeatures(feature_set=feature_set)
817+
self.assertTrue(transform.is_one_to_many)
814818
X = torch.rand(4, 5, 3, device=self.device, dtype=dtype)
815819
# in train - no transform
816820
transform.train()
@@ -1173,6 +1177,7 @@ def test_input_perturbation(self):
11731177
[[0.5, -0.3], [0.2, 0.4], [-0.7, 0.1]], device=self.device, dtype=dtype
11741178
)
11751179
transform = InputPerturbation(perturbation_set=perturbation_set)
1180+
self.assertTrue(transform.is_one_to_many)
11761181
X = torch.tensor(
11771182
[[[0.5, 0.5], [0.9, 0.7]], [[0.3, 0.2], [0.1, 0.4]]],
11781183
device=self.device,

0 commit comments

Comments
 (0)