2222
2323import torch
2424from botorch .exceptions .errors import BotorchTensorDimensionError
25- from botorch .models .transforms .utils import expand_and_copy_tensor
25+ from botorch .models .transforms .utils import subset_transform
2626from botorch .models .utils import fantasize
2727from botorch .utils .rounding import approximate_round
2828from gpytorch import Module as GPyTorchModule
@@ -387,6 +387,7 @@ def learn_coefficients(self, value: bool) -> None:
387387 """
388388 self ._learn_coefficients = value
389389
390+ @subset_transform
390391 def _transform (self , X : Tensor ) -> Tensor :
391392 r"""Apply affine transformation to input.
392393
@@ -400,13 +401,9 @@ def _transform(self, X: Tensor) -> Tensor:
400401 self ._check_shape (X )
401402 self ._update_coefficients (X )
402403 self ._to (X )
403- if hasattr (self , "indices" ):
404- X_new = X .clone ()
405- a , b = self .coefficient [..., self .indices ], self .offset [..., self .indices ]
406- X_new [..., self .indices ] = (X_new [..., self .indices ] - b ) / a
407- return X_new
408404 return (X - self .offset ) / self .coefficient
409405
406+ @subset_transform
410407 def _untransform (self , X : Tensor ) -> Tensor :
411408 r"""Apply inverse of affine transformation.
412409
@@ -417,11 +414,6 @@ def _untransform(self, X: Tensor) -> Tensor:
417414 A `batch_shape x n x d`-dim tensor of un-transformed inputs.
418415 """
419416 self ._to (X )
420- if hasattr (self , "indices" ):
421- X_new = X .clone ()
422- a , b = self .coefficient [..., self .indices ], self .offset [..., self .indices ]
423- X_new [..., self .indices ] = a * X_new [..., self .indices ] + b
424- return X_new
425417 return self .coefficient * X + self .offset
426418
427419 def equals (self , other : InputTransform ) -> bool :
@@ -523,18 +515,22 @@ def __init__(
523515 min_range: Amount of noise to add to the range to ensure no division by
524516 zero errors.
525517 """
518+ transform_dimension = d if indices is None else len (indices )
526519 if bounds is not None :
527- if bounds .size (- 1 ) != d :
520+ if indices is not None and bounds .size (- 1 ) == d :
521+ bounds = bounds [..., indices ]
522+ if bounds .size (- 1 ) != transform_dimension :
528523 raise BotorchTensorDimensionError (
529- "Dimensions of provided `bounds` are incompatible with `d`!"
524+ "Dimensions of provided `bounds` are incompatible with "
525+ f"transform_dimension = { transform_dimension } !"
530526 )
531527 offset = bounds [..., 0 :1 , :]
532528 coefficient = bounds [..., 1 :2 , :] - offset
533529 if coefficient .ndim > 2 :
534530 batch_shape = coefficient .shape [:- 2 ]
535531 else :
536- coefficient = torch .ones (* batch_shape , 1 , d )
537- offset = torch .zeros (* batch_shape , 1 , d )
532+ coefficient = torch .ones (* batch_shape , 1 , transform_dimension )
533+ offset = torch .zeros (* batch_shape , 1 , transform_dimension )
538534 self .learn_coefficients = True
539535 super ().__init__ (
540536 d = d ,
@@ -569,7 +565,6 @@ def learn_bounds(self) -> bool:
569565 def _update_coefficients (self , X ) -> None :
570566 """Computes the normalization bounds and updates the affine
571567 coefficients, which determine the base class's behavior.
572- NOTE: could drop inactive indices from bounds computation.
573568 """
574569 # Aggregate mins and ranges over extra batch and marginal dims
575570 batch_ndim = min (len (self .batch_shape ), X .ndim - 2 ) # batch rank of `X`
@@ -616,10 +611,11 @@ def __init__(
616611 min_std: Amount of noise to add to the standard deviation to ensure no
617612 division by zero errors.
618613 """
614+ transform_dimension = d if indices is None else len (indices )
619615 super ().__init__ (
620616 d = d ,
621- coefficient = torch .ones (* batch_shape , 1 , d ),
622- offset = torch .zeros (* batch_shape , 1 , d ),
617+ coefficient = torch .ones (* batch_shape , 1 , transform_dimension ),
618+ offset = torch .zeros (* batch_shape , 1 , transform_dimension ),
623619 indices = indices ,
624620 batch_shape = batch_shape ,
625621 transform_on_train = transform_on_train ,
@@ -641,7 +637,6 @@ def means(self):
641637 def _update_coefficients (self , X : Tensor ) -> None :
642638 """Computes the normalization bounds and updates the affine
643639 coefficients, which determine the base class's behavior.
644- NOTE: could drop inactive indices from bounds computation.
645640 """
646641 # Aggregate means and standard deviations over extra batch and marginal dims
647642 batch_ndim = min (len (self .batch_shape ), X .ndim - 2 ) # batch rank of `X`
@@ -722,6 +717,7 @@ def __init__(
722717 self .approximate = approximate
723718 self .tau = tau
724719
720+ @subset_transform
725721 def transform (self , X : Tensor ) -> Tensor :
726722 r"""Round the inputs.
727723
@@ -731,14 +727,7 @@ def transform(self, X: Tensor) -> Tensor:
731727 Returns:
732728 A `batch_shape x n x d`-dim tensor of rounded inputs.
733729 """
734- X_rounded = X .clone ()
735- X_int = X_rounded [..., self .indices ]
736- if self .approximate :
737- X_int = approximate_round (X_int , tau = self .tau )
738- else :
739- X_int = X_int .round ()
740- X_rounded [..., self .indices ] = X_int
741- return X_rounded
730+ return approximate_round (X , tau = self .tau ) if self .approximate else X .round ()
742731
743732 def equals (self , other : InputTransform ) -> bool :
744733 r"""Check if another input transform is equivalent.
@@ -787,6 +776,7 @@ def __init__(
787776 self .transform_on_fantasize = transform_on_fantasize
788777 self .reverse = reverse
789778
779+ @subset_transform
790780 def _transform (self , X : Tensor ) -> Tensor :
791781 r"""Log transform the inputs.
792782
@@ -796,10 +786,9 @@ def _transform(self, X: Tensor) -> Tensor:
796786 Returns:
797787 A `batch_shape x n x d`-dim tensor of transformed inputs.
798788 """
799- X_new = X .clone ()
800- X_new [..., self .indices ] = X_new [..., self .indices ].log10 ()
801- return X_new
789+ return X .log10 ()
802790
791+ @subset_transform
803792 def _untransform (self , X : Tensor ) -> Tensor :
804793 r"""Reverse the log transformation.
805794
@@ -809,9 +798,7 @@ def _untransform(self, X: Tensor) -> Tensor:
809798 Returns:
810799 A `batch_shape x n x d`-dim tensor of un-normalized inputs.
811800 """
812- X_new = X .clone ()
813- X_new [..., self .indices ] = 10.0 ** X_new [..., self .indices ]
814- return X_new
801+ return 10.0 ** X
815802
816803
817804class Warp (ReversibleInputTransform , GPyTorchModule ):
@@ -915,6 +902,7 @@ def _set_concentration(self, i: int, value: Union[float, Tensor]) -> None:
915902 value = torch .as_tensor (value ).to (self .concentration0 )
916903 self .initialize (** {f"concentration{ i } " : value })
917904
905+ @subset_transform
918906 def _transform (self , X : Tensor ) -> Tensor :
919907 r"""Warp the inputs through the Kumaraswamy CDF.
920908
@@ -927,20 +915,16 @@ def _transform(self, X: Tensor) -> Tensor:
927915 A `input_batch_shape x (batch_shape) x n x d`-dim tensor of transformed
928916 inputs.
929917 """
930- X_tf = expand_and_copy_tensor (X = X , batch_shape = self .batch_shape )
931- k = Kumaraswamy (
932- concentration1 = self .concentration1 , concentration0 = self .concentration0
933- )
934- # normalize to [eps, 1-eps]
935- X_tf [..., self .indices ] = k .cdf (
918+ # normalize to [eps, 1-eps], IDEA: could use Normalize and ChainedTransform.
919+ return self ._k .cdf (
936920 torch .clamp (
937- X_tf [..., self . indices ] * self ._X_range + self ._X_min ,
921+ X * self ._X_range + self ._X_min ,
938922 self ._X_min ,
939923 1.0 - self ._X_min ,
940924 )
941925 )
942- return X_tf
943926
927+ @subset_transform
944928 def _untransform (self , X : Tensor ) -> Tensor :
945929 r"""Warp the inputs through the Kumaraswamy inverse CDF.
946930
@@ -957,15 +941,16 @@ def _untransform(self, X: Tensor) -> Tensor:
957941 "The right most batch dims of X must match self.batch_shape: "
958942 f"({ self .batch_shape } )."
959943 )
960- X_tf = X .clone ()
961- k = Kumaraswamy (
962- concentration1 = self .concentration1 , concentration0 = self .concentration0
963- )
964944 # unnormalize from [eps, 1-eps] to [0,1]
965- X_tf [..., self .indices ] = (
966- (k .icdf (X_tf [..., self .indices ]) - self ._X_min ) / self ._X_range
967- ).clamp (0.0 , 1.0 )
968- return X_tf
945+ return ((self ._k .icdf (X ) - self ._X_min ) / self ._X_range ).clamp (0.0 , 1.0 )
946+
947+ @property
948+ def _k (self ) -> Kumaraswamy :
949+ """Returns a Kumaraswamy distribution with the concentration parameters."""
950+ return Kumaraswamy (
951+ concentration1 = self .concentration1 ,
952+ concentration0 = self .concentration0 ,
953+ )
969954
970955
971956class AppendFeatures (InputTransform , Module ):
@@ -1225,6 +1210,7 @@ def __init__(
12251210 self ,
12261211 perturbation_set : Union [Tensor , Callable [[Tensor ], Tensor ]],
12271212 bounds : Optional [Tensor ] = None ,
1213+ indices : Optional [List [int ]] = None ,
12281214 multiplicative : bool = False ,
12291215 transform_on_train : bool = False ,
12301216 transform_on_eval : bool = True ,
@@ -1240,6 +1226,10 @@ def __init__(
12401226 bounds: A `2 x d`-dim tensor of lower and upper bounds for each
12411227 column of the input. If given, the perturbed inputs will be
12421228 clamped to these bounds.
1229+ indices: A list of indices specifying a subset of inputs on which to apply
1230+ the transform. Note that `len(indices)` should be equal to the second
1231+ dimension of `perturbation_set` and `bounds`. The dimensionality of
1232+ the input `X.shape[-1]` can be larger if we only transform a subset.
12431233 multiplicative: A boolean indicating whether the input perturbations
12441234 are additive or multiplicative. If True, inputs will be multiplied
12451235 with the perturbations.
@@ -1270,6 +1260,8 @@ def __init__(
12701260 self .register_buffer ("bounds" , bounds )
12711261 else :
12721262 self .bounds = None
1263+ self .register_buffer ("_perturbations" , None )
1264+ self .indices = indices
12731265 self .multiplicative = multiplicative
12741266 self .transform_on_train = transform_on_train
12751267 self .transform_on_eval = transform_on_eval
@@ -1294,21 +1286,36 @@ def transform(self, X: Tensor) -> Tensor:
12941286 Returns:
12951287 A `batch_shape x (q * n_p) x d`-dim tensor of perturbed inputs.
12961288 """
1297- if isinstance (self .perturbation_set , Tensor ):
1298- perturbations = self .perturbation_set
1299- else :
1300- perturbations = self .perturbation_set (X )
1301- expanded_X = X .unsqueeze (dim = - 2 ).expand (
1302- * X .shape [:- 1 ], perturbations .shape [- 2 ], - 1
1303- )
1304- expanded_perturbations = perturbations .expand (* expanded_X .shape [:- 1 ], - 1 )
1305- if self .multiplicative :
1306- perturbed_inputs = expanded_X * expanded_perturbations
1307- else :
1308- perturbed_inputs = expanded_X + expanded_perturbations
1309- perturbed_inputs = perturbed_inputs .reshape (* X .shape [:- 2 ], - 1 , X .shape [- 1 ])
1289+ # NOTE: If we had access to n_p without evaluating _perturbations when the
1290+ # perturbation_set is a function, we could move this into `_transform`.
1291+ # Further, we could remove the two `transpose` calls below if one were
1292+ # willing to accept a different ordering of the transformed output.
1293+ self ._perturbations = self ._expanded_perturbations (X )
1294+ # make space for n_p dimension, switch n_p with n after transform, and flatten.
1295+ return self ._transform (X .unsqueeze (- 3 )).transpose (- 3 , - 2 ).flatten (- 3 , - 2 )
1296+
1297+ @subset_transform
1298+ def _transform (self , X : Tensor ):
1299+ p = self ._perturbations
1300+ Y = X * p if self .multiplicative else X + p
13101301 if self .bounds is not None :
1311- perturbed_inputs = torch .maximum (
1312- torch .minimum (perturbed_inputs , self .bounds [1 ]), self .bounds [0 ]
1313- )
1314- return perturbed_inputs
1302+ return torch .maximum (torch .minimum (Y , self .bounds [1 ]), self .bounds [0 ])
1303+ return Y
1304+
1305+ @property
1306+ def batch_shape (self ):
1307+ """Returns a shape tuple such that `subset_transform` pre-allocates
1308+ a (b x n_p x n x d) - dim tensor, where `b` is the batch shape of the
1309+ input `X` of the transform and `n_p` is the number of perturbations.
1310+ NOTE: this function is dependent on calling `_expanded_perturbations(X)`
1311+ because `n_p` is inaccessible otherwise if `perturbation_set` is a function.
1312+ """
1313+ return self ._perturbations .shape [:- 2 ]
1314+
1315+ def _expanded_perturbations (self , X : Tensor ) -> Tensor :
1316+ p = self .perturbation_set
1317+ if isinstance (p , Tensor ):
1318+ p = p .expand (X .shape [- 2 ], * p .shape ) # p is batch_shape x n x n_p x d
1319+ else :
1320+ p = p (X ) if self .indices is None else p (X [..., self .indices ])
1321+ return p .transpose (- 3 , - 2 ) # p is batch_shape x n_p x n x d
0 commit comments