
Commit 208470e

SebastianAment authored and facebook-github-bot committed
Introducing subset_transform decorator (#1468)
Summary:
Pull Request resolved: #1468

Most `InputTransforms` have an `indices` field that specifies a feature subset on which to apply the transform. This diff introduces the `subset_transform` decorator, which separates out the indexing logic, thereby simplifying the implementation of multiple `transform` and `untransform` methods, as well as adding support for `indices` in `InputPerturbation`.

Reviewed By: saitcakmak

Differential Revision: D40620269

fbshipit-source-id: 05fe093569d12bb19f9597044317efe270ba9355
1 parent d34a568 commit 208470e
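
For illustration only (not part of this commit's code), here is a minimal sketch of the pattern the decorator factors out; the `LogBefore`/`LogAfter` classes below are hypothetical stand-ins for the real transforms, while the `subset_transform` import is the helper added by this commit:

import torch
from torch import Tensor

from botorch.models.transforms.utils import subset_transform


class LogBefore:
    """Hypothetical pre-refactor transform: the indexing logic is written by hand."""

    def __init__(self, indices=None):
        self.indices = indices

    def _transform(self, X: Tensor) -> Tensor:
        if self.indices is not None:
            X_new = X.clone()
            X_new[..., self.indices] = X_new[..., self.indices].log10()
            return X_new
        return X.log10()


class LogAfter:
    """Same transform after the refactor: the decorator handles the subsetting."""

    def __init__(self, indices=None):
        self.indices = indices

    @subset_transform
    def _transform(self, X: Tensor) -> Tensor:
        return X.log10()


X = torch.tensor([[10.0, 100.0, 1000.0]])
assert torch.equal(
    LogBefore(indices=[1])._transform(X), LogAfter(indices=[1])._transform(X)
)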

File tree

3 files changed (+113 −73 lines):

botorch/models/transforms/input.py
botorch/models/transforms/utils.py
test/models/transforms/test_input.py


botorch/models/transforms/input.py

Lines changed: 73 additions & 66 deletions
@@ -22,7 +22,7 @@

 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
-from botorch.models.transforms.utils import expand_and_copy_tensor
+from botorch.models.transforms.utils import subset_transform
 from botorch.models.utils import fantasize
 from botorch.utils.rounding import approximate_round
 from gpytorch import Module as GPyTorchModule
@@ -387,6 +387,7 @@ def learn_coefficients(self, value: bool) -> None:
         """
         self._learn_coefficients = value

+    @subset_transform
     def _transform(self, X: Tensor) -> Tensor:
         r"""Apply affine transformation to input.

@@ -400,13 +401,9 @@ def _transform(self, X: Tensor) -> Tensor:
             self._check_shape(X)
             self._update_coefficients(X)
         self._to(X)
-        if hasattr(self, "indices"):
-            X_new = X.clone()
-            a, b = self.coefficient[..., self.indices], self.offset[..., self.indices]
-            X_new[..., self.indices] = (X_new[..., self.indices] - b) / a
-            return X_new
         return (X - self.offset) / self.coefficient

+    @subset_transform
     def _untransform(self, X: Tensor) -> Tensor:
         r"""Apply inverse of affine transformation.

@@ -417,11 +414,6 @@ def _untransform(self, X: Tensor) -> Tensor:
             A `batch_shape x n x d`-dim tensor of un-transformed inputs.
         """
         self._to(X)
-        if hasattr(self, "indices"):
-            X_new = X.clone()
-            a, b = self.coefficient[..., self.indices], self.offset[..., self.indices]
-            X_new[..., self.indices] = a * X_new[..., self.indices] + b
-            return X_new
         return self.coefficient * X + self.offset

     def equals(self, other: InputTransform) -> bool:
@@ -523,18 +515,22 @@ def __init__(
             min_range: Amount of noise to add to the range to ensure no division by
                 zero errors.
         """
+        transform_dimension = d if indices is None else len(indices)
         if bounds is not None:
-            if bounds.size(-1) != d:
+            if indices is not None and bounds.size(-1) == d:
+                bounds = bounds[..., indices]
+            if bounds.size(-1) != transform_dimension:
                 raise BotorchTensorDimensionError(
-                    "Dimensions of provided `bounds` are incompatible with `d`!"
+                    "Dimensions of provided `bounds` are incompatible with "
+                    f"transform_dimension = {transform_dimension}!"
                 )
             offset = bounds[..., 0:1, :]
             coefficient = bounds[..., 1:2, :] - offset
             if coefficient.ndim > 2:
                 batch_shape = coefficient.shape[:-2]
         else:
-            coefficient = torch.ones(*batch_shape, 1, d)
-            offset = torch.zeros(*batch_shape, 1, d)
+            coefficient = torch.ones(*batch_shape, 1, transform_dimension)
+            offset = torch.zeros(*batch_shape, 1, transform_dimension)
             self.learn_coefficients = True
         super().__init__(
             d=d,
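
As an aside (not part of this commit), a rough usage sketch of the new bounds handling in `Normalize`: when `indices` is given, full `2 x d` bounds are accepted and subset internally to the transformed columns, and the coefficients have width `transform_dimension` rather than `d`. The tensor values below are made up:

import torch
from botorch.models.transforms.input import Normalize

# d=3 inputs, but only normalize columns 0 and 2 (transform_dimension = 2).
bounds = torch.tensor([[0.0, -1.0, 10.0], [1.0, 1.0, 20.0]])  # full 2 x d bounds
nlz = Normalize(d=3, indices=[0, 2], bounds=bounds)

X = torch.tensor([[0.5, 0.3, 15.0]])
print(nlz(X))  # columns 0 and 2 are mapped to [0, 1]; column 1 passes through unchanged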
@@ -569,7 +565,6 @@ def learn_bounds(self) -> bool:
     def _update_coefficients(self, X) -> None:
         """Computes the normalization bounds and updates the affine
         coefficients, which determine the base class's behavior.
-        NOTE: could drop inactive indices from bounds computation.
         """
         # Aggregate mins and ranges over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
@@ -616,10 +611,11 @@ def __init__(
             min_std: Amount of noise to add to the standard deviation to ensure no
                 division by zero errors.
         """
+        transform_dimension = d if indices is None else len(indices)
         super().__init__(
             d=d,
-            coefficient=torch.ones(*batch_shape, 1, d),
-            offset=torch.zeros(*batch_shape, 1, d),
+            coefficient=torch.ones(*batch_shape, 1, transform_dimension),
+            offset=torch.zeros(*batch_shape, 1, transform_dimension),
             indices=indices,
             batch_shape=batch_shape,
             transform_on_train=transform_on_train,
@@ -641,7 +637,6 @@ def means(self):
     def _update_coefficients(self, X: Tensor) -> None:
         """Computes the normalization bounds and updates the affine
         coefficients, which determine the base class's behavior.
-        NOTE: could drop inactive indices from bounds computation.
         """
         # Aggregate means and standard deviations over extra batch and marginal dims
         batch_ndim = min(len(self.batch_shape), X.ndim - 2)  # batch rank of `X`
@@ -722,6 +717,7 @@ def __init__(
         self.approximate = approximate
         self.tau = tau

+    @subset_transform
     def transform(self, X: Tensor) -> Tensor:
         r"""Round the inputs.

@@ -731,14 +727,7 @@ def transform(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x d`-dim tensor of rounded inputs.
         """
-        X_rounded = X.clone()
-        X_int = X_rounded[..., self.indices]
-        if self.approximate:
-            X_int = approximate_round(X_int, tau=self.tau)
-        else:
-            X_int = X_int.round()
-        X_rounded[..., self.indices] = X_int
-        return X_rounded
+        return approximate_round(X, tau=self.tau) if self.approximate else X.round()

     def equals(self, other: InputTransform) -> bool:
         r"""Check if another input transform is equivalent.
@@ -787,6 +776,7 @@ def __init__(
         self.transform_on_fantasize = transform_on_fantasize
         self.reverse = reverse

+    @subset_transform
     def _transform(self, X: Tensor) -> Tensor:
         r"""Log transform the inputs.

@@ -796,10 +786,9 @@ def _transform(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x d`-dim tensor of transformed inputs.
         """
-        X_new = X.clone()
-        X_new[..., self.indices] = X_new[..., self.indices].log10()
-        return X_new
+        return X.log10()

+    @subset_transform
     def _untransform(self, X: Tensor) -> Tensor:
         r"""Reverse the log transformation.

@@ -809,9 +798,7 @@ def _untransform(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x d`-dim tensor of un-normalized inputs.
         """
-        X_new = X.clone()
-        X_new[..., self.indices] = 10.0 ** X_new[..., self.indices]
-        return X_new
+        return 10.0**X


 class Warp(ReversibleInputTransform, GPyTorchModule):
@@ -915,6 +902,7 @@ def _set_concentration(self, i: int, value: Union[float, Tensor]) -> None:
         value = torch.as_tensor(value).to(self.concentration0)
         self.initialize(**{f"concentration{i}": value})

+    @subset_transform
     def _transform(self, X: Tensor) -> Tensor:
         r"""Warp the inputs through the Kumaraswamy CDF.

@@ -927,20 +915,16 @@ def _transform(self, X: Tensor) -> Tensor:
             A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
                 inputs.
         """
-        X_tf = expand_and_copy_tensor(X=X, batch_shape=self.batch_shape)
-        k = Kumaraswamy(
-            concentration1=self.concentration1, concentration0=self.concentration0
-        )
-        # normalize to [eps, 1-eps]
-        X_tf[..., self.indices] = k.cdf(
+        # normalize to [eps, 1-eps], IDEA: could use Normalize and ChainedTransform.
+        return self._k.cdf(
             torch.clamp(
-                X_tf[..., self.indices] * self._X_range + self._X_min,
+                X * self._X_range + self._X_min,
                 self._X_min,
                 1.0 - self._X_min,
             )
         )
-        return X_tf

+    @subset_transform
     def _untransform(self, X: Tensor) -> Tensor:
         r"""Warp the inputs through the Kumaraswamy inverse CDF.

@@ -957,15 +941,16 @@ def _untransform(self, X: Tensor) -> Tensor:
                 "The right most batch dims of X must match self.batch_shape: "
                 f"({self.batch_shape})."
             )
-        X_tf = X.clone()
-        k = Kumaraswamy(
-            concentration1=self.concentration1, concentration0=self.concentration0
-        )
         # unnormalize from [eps, 1-eps] to [0,1]
-        X_tf[..., self.indices] = (
-            (k.icdf(X_tf[..., self.indices]) - self._X_min) / self._X_range
-        ).clamp(0.0, 1.0)
-        return X_tf
+        return ((self._k.icdf(X) - self._X_min) / self._X_range).clamp(0.0, 1.0)
+
+    @property
+    def _k(self) -> Kumaraswamy:
+        """Returns a Kumaraswamy distribution with the concentration parameters."""
+        return Kumaraswamy(
+            concentration1=self.concentration1,
+            concentration0=self.concentration0,
+        )


 class AppendFeatures(InputTransform, Module):
@@ -1225,6 +1210,7 @@ def __init__(
         self,
         perturbation_set: Union[Tensor, Callable[[Tensor], Tensor]],
         bounds: Optional[Tensor] = None,
+        indices: Optional[List[int]] = None,
         multiplicative: bool = False,
         transform_on_train: bool = False,
         transform_on_eval: bool = True,
@@ -1240,6 +1226,10 @@ def __init__(
             bounds: A `2 x d`-dim tensor of lower and upper bounds for each
                 column of the input. If given, the perturbed inputs will be
                 clamped to these bounds.
+            indices: A list of indices specifying a subset of inputs on which to apply
+                the transform. Note that `len(indices)` should be equal to the second
+                dimension of `perturbation_set` and `bounds`. The dimensionality of
+                the input `X.shape[-1]` can be larger if we only transform a subset.
             multiplicative: A boolean indicating whether the input perturbations
                 are additive or multiplicative. If True, inputs will be multiplied
                 with the perturbations.
@@ -1270,6 +1260,8 @@ def __init__(
             self.register_buffer("bounds", bounds)
         else:
             self.bounds = None
+        self.register_buffer("_perturbations", None)
+        self.indices = indices
         self.multiplicative = multiplicative
         self.transform_on_train = transform_on_train
         self.transform_on_eval = transform_on_eval
@@ -1294,21 +1286,36 @@ def transform(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs.
         """
-        if isinstance(self.perturbation_set, Tensor):
-            perturbations = self.perturbation_set
-        else:
-            perturbations = self.perturbation_set(X)
-        expanded_X = X.unsqueeze(dim=-2).expand(
-            *X.shape[:-1], perturbations.shape[-2], -1
-        )
-        expanded_perturbations = perturbations.expand(*expanded_X.shape[:-1], -1)
-        if self.multiplicative:
-            perturbed_inputs = expanded_X * expanded_perturbations
-        else:
-            perturbed_inputs = expanded_X + expanded_perturbations
-        perturbed_inputs = perturbed_inputs.reshape(*X.shape[:-2], -1, X.shape[-1])
+        # NOTE: If we had access to n_p without evaluating _perturbations when the
+        # perturbation_set is a function, we could move this into `_transform`.
+        # Further, we could remove the two `transpose` calls below if one were
+        # willing to accept a different ordering of the transformed output.
+        self._perturbations = self._expanded_perturbations(X)
+        # make space for n_p dimension, switch n_p with n after transform, and flatten.
+        return self._transform(X.unsqueeze(-3)).transpose(-3, -2).flatten(-3, -2)
+
+    @subset_transform
+    def _transform(self, X: Tensor):
+        p = self._perturbations
+        Y = X * p if self.multiplicative else X + p
         if self.bounds is not None:
-            perturbed_inputs = torch.maximum(
-                torch.minimum(perturbed_inputs, self.bounds[1]), self.bounds[0]
-            )
-        return perturbed_inputs
+            return torch.maximum(torch.minimum(Y, self.bounds[1]), self.bounds[0])
+        return Y
+
+    @property
+    def batch_shape(self):
+        """Returns a shape tuple such that `subset_transform` pre-allocates
+        a (b x n_p x n x d) - dim tensor, where `b` is the batch shape of the
+        input `X` of the transform and `n_p` is the number of perturbations.
+        NOTE: this function is dependent on calling `_expanded_perturbations(X)`
+        because `n_p` is inaccessible otherwise if `perturbation_set` is a function.
+        """
+        return self._perturbations.shape[:-2]
+
+    def _expanded_perturbations(self, X: Tensor) -> Tensor:
+        p = self.perturbation_set
+        if isinstance(p, Tensor):
+            p = p.expand(X.shape[-2], *p.shape)  # p is batch_shape x n x n_p x d
+        else:
+            p = p(X) if self.indices is None else p(X[..., self.indices])
+        return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d
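
To make the new `indices` support in `InputPerturbation` and its shape bookkeeping concrete, here is a rough usage sketch assuming the post-commit API; the tensor values are made up:

import torch
from botorch.models.transforms.input import InputPerturbation

# n=5 points in d=4; perturb only the first two features with n_p=3 perturbations.
X = torch.rand(5, 4)
perturbation_set = torch.tensor([[0.1, 0.2], [-0.1, 0.0], [0.0, 0.3]])  # n_p x len(indices)

pert = InputPerturbation(perturbation_set=perturbation_set, indices=[0, 1]).eval()
perturbed = pert(X)
# Each of the n=5 rows is repeated n_p=3 times; only columns 0 and 1 are shifted,
# while columns 2 and 3 pass through unchanged.
print(perturbed.shape)  # torch.Size([15, 4])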

botorch/models/transforms/utils.py

Lines changed: 17 additions & 0 deletions
@@ -6,6 +6,8 @@

 from __future__ import annotations

+from functools import wraps
+
 from typing import Tuple

 import torch
@@ -111,3 +113,18 @@ def expand_and_copy_tensor(X: Tensor, batch_shape: torch.Size) -> Tensor:
         )
     expand_shape = batch_shape + X.shape[-2:]
     return X.expand(expand_shape).clone()
+
+
+def subset_transform(transform):
+    r"""Decorator of an input transform function to separate out indexing logic."""
+
+    @wraps(transform)
+    def f(self, X: Tensor) -> Tensor:
+        if not hasattr(self, "indices") or self.indices is None:
+            return transform(self, X)
+        has_shape = hasattr(self, "batch_shape")
+        Y = expand_and_copy_tensor(X, self.batch_shape) if has_shape else X.clone()
+        Y[..., self.indices] = transform(self, X[..., self.indices])
+        return Y
+
+    return f
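
A small usage sketch of the decorator outside of BoTorch's own transforms; the `Shifter` class below is hypothetical. If the decorated object exposes a `batch_shape`, the wrapper pre-allocates the output via `expand_and_copy_tensor`, which is how `InputPerturbation` gains an extra perturbation dimension:

import torch
from torch import Tensor

from botorch.models.transforms.utils import subset_transform


class Shifter:
    """Hypothetical transform: adds 1 to the features listed in `indices`."""

    def __init__(self, indices=None, batch_shape=torch.Size()):
        self.indices = indices
        self.batch_shape = batch_shape

    @subset_transform
    def _transform(self, X: Tensor) -> Tensor:
        return X + 1.0


X = torch.rand(4, 3)
print(Shifter(indices=[0])._transform(X).shape)  # torch.Size([4, 3]); only column 0 shifted
print(Shifter(indices=[0], batch_shape=torch.Size([2]))._transform(X).shape)  # torch.Size([2, 4, 3])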

test/models/transforms/test_input.py

Lines changed: 23 additions & 7 deletions
@@ -176,8 +176,8 @@ def test_normalize(self):
             self.assertTrue(nlz.learn_bounds)
             self.assertTrue(nlz.training)
             self.assertEqual(nlz._d, 2)
-            self.assertEqual(nlz.mins.shape, torch.Size([1, 2]))
-            self.assertEqual(nlz.ranges.shape, torch.Size([1, 2]))
+            self.assertEqual(nlz.mins.shape, torch.Size([1, 1]))
+            self.assertEqual(nlz.ranges.shape, torch.Size([1, 1]))
             self.assertEqual(len(nlz.indices), 1)
             self.assertTrue((nlz.indices == torch.tensor([0], dtype=torch.long)).all())

@@ -284,7 +284,7 @@ def test_normalize(self):
             expected_bounds = torch.cat(
                 [X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
                 dim=-2,
-            )
+            )[..., indices]
             self.assertTrue(
                 torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
             )
@@ -349,17 +349,17 @@ def test_standardize(self):
             stdz = InputStandardize(d=2, indices=[0])
             self.assertTrue(stdz.training)
             self.assertEqual(stdz._d, 2)
-            self.assertEqual(stdz.means.shape, torch.Size([1, 2]))
-            self.assertEqual(stdz.stds.shape, torch.Size([1, 2]))
+            self.assertEqual(stdz.means.shape, torch.Size([1, 1]))
+            self.assertEqual(stdz.stds.shape, torch.Size([1, 1]))
             self.assertEqual(len(stdz.indices), 1)
             self.assertTrue(
                 torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
             )
             stdz = InputStandardize(d=2, indices=[0], batch_shape=torch.Size([3]))
             self.assertTrue(stdz.training)
             self.assertEqual(stdz._d, 2)
-            self.assertEqual(stdz.means.shape, torch.Size([3, 1, 2]))
-            self.assertEqual(stdz.stds.shape, torch.Size([3, 1, 2]))
+            self.assertEqual(stdz.means.shape, torch.Size([3, 1, 1]))
+            self.assertEqual(stdz.stds.shape, torch.Size([3, 1, 1]))
             self.assertEqual(len(stdz.indices), 1)
             self.assertTrue(
                 torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
@@ -1308,3 +1308,19 @@ def perturbation_generator(X: Tensor) -> Tensor:
                 dim=-2,
             )
             self.assertTrue(torch.allclose(transformed, expected))
+
+            # testing same heteroscedastic transform with subset of indices
+            indices = [0, 1]
+            subset_transform = InputPerturbation(
+                perturbation_set=perturbation_generator, indices=indices
+            ).eval()
+            X_repeat = X.repeat(1, 1, 2)
+            subset_transformed = subset_transform(X_repeat)
+            # first set of two indices are the same as with previous transform
+            self.assertTrue(torch.allclose(subset_transformed[..., :2], expected))
+
+            # second set of two indices are untransformed but have expanded batch shape
+            num_pert = subset_transform.batch_shape[-1]
+            sec_expected = X.unsqueeze(-2).expand(*X.shape[:-1], num_pert, -1)
+            sec_expected = sec_expected.flatten(-3, -2)
+            self.assertTrue(torch.allclose(subset_transformed[..., 2:], sec_expected))
