Skip to content

Commit 97626a9

Browse files
sdaulton authored and facebook-github-bot committed
add support for categoricals in Round input transform and use STEs (#1516)
Summary: Pull Request resolved: #1516. See title. Reviewed By: Balandat. Differential Revision: D41477456. fbshipit-source-id: 21e500b887349f8164223fee696e46c506d61ab2
1 parent 9da6f22 commit 97626a9

File tree

2 files changed

+147
-33
lines changed

2 files changed

+147
-33
lines changed

botorch/models/transforms/input.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212
rounding functions, and log transformations. The input transformation
1313
is typically part of a Model and applied within the model.forward()
1414
method.
15-
1615
"""
1716
from __future__ import annotations
1817

1918
from abc import ABC, abstractmethod
2019
from collections import OrderedDict
2120
from typing import Any, Callable, Dict, List, Optional, Union
21+
from warnings import warn
2222

2323
import torch
2424
from botorch.exceptions.errors import BotorchTensorDimensionError
2525
from botorch.models.transforms.utils import subset_transform
2626
from botorch.models.utils import fantasize
27-
from botorch.utils.rounding import approximate_round
27+
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
2828
from gpytorch import Module as GPyTorchModule
2929
from gpytorch.constraints import GreaterThan
3030
from gpytorch.priors import Prior
@@ -649,10 +649,10 @@ def _update_coefficients(self, X: Tensor) -> None:
649649

650650

651651
class Round(InputTransform, Module):
652-
r"""A rounding transformation for integer inputs.
652+
r"""A discretization transformation for discrete inputs.
653653
654-
This will typically be used in conjunction with normalization as
655-
follows:
654+
For integers, this will typically be used in conjunction
655+
with normalization as follows:
656656
657657
In eval() mode (i.e. after training), the inputs pass
658658
would typically be normalized to the unit cube (e.g. during candidate
@@ -667,19 +667,26 @@ class Round(InputTransform, Module):
667667
should be set to False, so that the raw inputs are rounded and then
668668
normalized to the unit cube.
669669
670-
This transformation uses differentiable approximate rounding by default.
671-
The rounding function is approximated with a piece-wise function where
672-
each piece is a hyperbolic tangent function.
670+
By default, straight-through estimators are used for the gradients as
671+
proposed in [Daulton2022bopr]_. This transformation supports differentiable
672+
approximate rounding (currently only for integers). The rounding function
673+
is approximated with a piece-wise function where each piece is a hyperbolic
674+
tangent function.
675+
676+
For categorical parameters, the input must be one-hot encoded.
673677
674678
Example:
679+
>>> bounds = torch.tensor([[0, 5], [0, 1], [0, 1]]).t()
680+
>>> integer_indices = [0]
681+
>>> categorical_features = {1: 2}
675682
>>> unnormalize_tf = Normalize(
676683
>>> d=d,
677684
>>> bounds=bounds,
678685
>>> transform_on_eval=True,
679686
>>> transform_on_train=True,
680687
>>> reverse=True,
681688
>>> )
682-
>>> round_tf = Round(integer_indices)
689+
>>> round_tf = Round(integer_indices, categorical_features)
683690
>>> normalize_tf = Normalize(d=d, bounds=bounds)
684691
>>> tf = ChainedInputTransform(
685692
>>> tf1=unnormalize_tf, tf2=round_tf, tf3=normalize_tf
@@ -688,46 +695,76 @@ class Round(InputTransform, Module):
688695

689696
def __init__(
690697
self,
691-
indices: List[int],
698+
integer_indices: Optional[List[int]] = None,
699+
categorical_features: Optional[Dict[int, int]] = None,
692700
transform_on_train: bool = True,
693701
transform_on_eval: bool = True,
694702
transform_on_fantasize: bool = True,
695-
approximate: bool = True,
703+
approximate: bool = False,
696704
tau: float = 1e-3,
705+
**kwargs,
697706
) -> None:
698707
r"""Initialize transform.
699708
700709
Args:
701-
indices: The indices of the integer inputs.
710+
integer_indices: The indices of the integer inputs.
711+
categorical_features: A dictionary mapping the starting index of each
712+
categorical feature to its cardinality. This assumes that categoricals
713+
are one-hot encoded.
702714
transform_on_train: A boolean indicating whether to apply the
703715
transforms in train() mode. Default: True.
704716
transform_on_eval: A boolean indicating whether to apply the
705717
transform in eval() mode. Default: True.
706718
transform_on_fantasize: A boolean indicating whether to apply the
707719
transform when called from within a `fantasize` call. Default: True.
708720
approximate: A boolean indicating whether approximate or exact
709-
rounding should be used. Default: approximate.
721+
rounding should be used. Default: False.
710722
tau: The temperature parameter for approximate rounding.
711723
"""
724+
indices = kwargs.get("indices")
725+
if indices is not None:
726+
warn(
727+
"`indices` is marked for deprecation in favor of `integer_indices`.",
728+
DeprecationWarning,
729+
)
730+
integer_indices = indices
731+
if approximate and categorical_features is not None:
732+
raise NotImplementedError
712733
super().__init__()
713734
self.transform_on_train = transform_on_train
714735
self.transform_on_eval = transform_on_eval
715736
self.transform_on_fantasize = transform_on_fantasize
716-
self.register_buffer("indices", torch.tensor(indices, dtype=torch.long))
737+
integer_indices = integer_indices or []
738+
self.register_buffer(
739+
"integer_indices", torch.tensor(integer_indices, dtype=torch.long)
740+
)
741+
self.categorical_features = categorical_features or {}
717742
self.approximate = approximate
718743
self.tau = tau
719744

720-
@subset_transform
721745
def transform(self, X: Tensor) -> Tensor:
722-
r"""Round the inputs.
746+
r"""Discretize the inputs.
723747
724748
Args:
725749
X: A `batch_shape x n x d`-dim tensor of inputs.
726750
727751
Returns:
728-
A `batch_shape x n x d`-dim tensor of rounded inputs.
752+
A `batch_shape x n x d`-dim tensor of discretized inputs.
729753
"""
730-
return approximate_round(X, tau=self.tau) if self.approximate else X.round()
754+
X_rounded = X.clone()
755+
# round integers
756+
X_int = X_rounded[..., self.integer_indices]
757+
if self.approximate:
758+
X_int = approximate_round(X_int, tau=self.tau)
759+
else:
760+
X_int = RoundSTE.apply(X_int)
761+
X_rounded[..., self.integer_indices] = X_int
762+
# discretize categoricals to the category with the largest value
763+
# in the continuous relaxation of the one-hot encoding
764+
for start, card in self.categorical_features.items():
765+
end = start + card
766+
X_rounded[..., start:end] = OneHotArgmaxSTE.apply(X[..., start:end])
767+
return X_rounded
731768

732769
def equals(self, other: InputTransform) -> bool:
733770
r"""Check if another input transform is equivalent.
@@ -740,6 +777,8 @@ def equals(self, other: InputTransform) -> bool:
740777
"""
741778
return (
742779
super().equals(other=other)
780+
and (self.integer_indices == other.integer_indices).all()
781+
and self.categorical_features == other.categorical_features
743782
and self.approximate == other.approximate
744783
and self.tau == other.tau
745784
)

test/models/transforms/test_input.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import itertools
8+
import warnings
89
from copy import deepcopy
910

1011
import torch
12+
from botorch import settings
1113
from botorch.exceptions.errors import BotorchTensorDimensionError
1214
from botorch.models.transforms.input import (
1315
AffineInputTransform,
@@ -29,6 +31,7 @@
2931
from torch import Tensor
3032
from torch.distributions import Kumaraswamy
3133
from torch.nn import Module
34+
from torch.nn.functional import one_hot
3235

3336

3437
def get_test_warp(indices, **kwargs):
@@ -534,19 +537,45 @@ def test_chained_input_transform(self):
534537
def test_round_transform(self):
535538
for dtype in (torch.float, torch.double):
536539
# basic init
537-
int_idcs = [0, 2]
538-
round_tf = Round(indices=[0, 2])
539-
self.assertEqual(round_tf.indices.tolist(), int_idcs)
540+
int_idcs = [0, 4]
541+
categorical_feats = {2: 2, 5: 3}
542+
# test deprecation warning
543+
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
544+
Round(indices=int_idcs)
545+
self.assertTrue(
546+
any(issubclass(w.category, DeprecationWarning) for w in ws)
547+
)
548+
round_tf = Round(
549+
integer_indices=int_idcs, categorical_features=categorical_feats
550+
)
551+
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)
552+
self.assertEqual(round_tf.categorical_features, categorical_feats)
540553
self.assertTrue(round_tf.training)
541-
self.assertTrue(round_tf.approximate)
554+
self.assertFalse(round_tf.approximate)
542555
self.assertEqual(round_tf.tau, 1e-3)
543556

544557
# basic usage
545-
for batch_shape, approx in itertools.product(
546-
(torch.Size(), torch.Size([3])), (False, True)
558+
for batch_shape, approx, categorical_features in itertools.product(
559+
(torch.Size(), torch.Size([3])),
560+
(False, True),
561+
(None, categorical_feats),
547562
):
548-
X = 5 * torch.rand(*batch_shape, 4, 3, device=self.device, dtype=dtype)
549-
round_tf = Round(indices=[0, 2], approximate=approx)
563+
X = torch.rand(*batch_shape, 4, 8, device=self.device, dtype=dtype)
564+
X[..., int_idcs] *= 5
565+
if categorical_features is not None and approx:
566+
with self.assertRaises(NotImplementedError):
567+
Round(
568+
integer_indices=int_idcs,
569+
categorical_features=categorical_features,
570+
approximate=approx,
571+
)
572+
continue
573+
round_tf = Round(
574+
integer_indices=int_idcs,
575+
categorical_features=categorical_features,
576+
approximate=approx,
577+
tau=1e-1,
578+
)
550579
X_rounded = round_tf(X)
551580
exact_rounded_X_ints = X[..., int_idcs].round()
552581
# check non-integer parameters are unchanged
@@ -560,17 +589,39 @@ def test_round_transform(self):
560589
<= (X[..., int_idcs] - exact_rounded_X_ints).abs()
561590
).all()
562591
)
592+
self.assertFalse(
593+
torch.equal(X_rounded[..., int_idcs], exact_rounded_X_ints)
594+
)
563595
else:
564-
# check that exact rounding behaves as expected
596+
# check that exact rounding behaves as expected for integers
565597
self.assertTrue(
566598
torch.equal(X_rounded[..., int_idcs], exact_rounded_X_ints)
567599
)
600+
if categorical_features is not None:
601+
# test that discretization works as expected for categoricals
602+
for start, card in categorical_features.items():
603+
end = start + card
604+
expected_categorical = one_hot(
605+
X[..., start:end].argmax(dim=-1), num_classes=card
606+
).to(X)
607+
self.assertTrue(
608+
torch.equal(
609+
X_rounded[..., start:end], expected_categorical
610+
)
611+
)
612+
# test that gradient information is passed via STE
613+
X2 = X.clone().requires_grad_(True)
614+
round_tf(X2).sum().backward()
615+
self.assertTrue(torch.equal(X2.grad, torch.ones_like(X2)))
568616
with self.assertRaises(NotImplementedError):
569617
round_tf.untransform(X_rounded)
570618

571619
# test no transform on eval
572620
round_tf = Round(
573-
indices=int_idcs, approximate=approx, transform_on_eval=False
621+
integer_indices=int_idcs,
622+
categorical_features=categorical_features,
623+
approximate=approx,
624+
transform_on_eval=False,
574625
)
575626
X_rounded = round_tf(X)
576627
self.assertFalse(torch.equal(X, X_rounded))
@@ -580,7 +631,10 @@ def test_round_transform(self):
580631

581632
# test no transform on train
582633
round_tf = Round(
583-
indices=int_idcs, approximate=approx, transform_on_train=False
634+
integer_indices=int_idcs,
635+
categorical_features=categorical_features,
636+
approximate=approx,
637+
transform_on_train=False,
584638
)
585639
X_rounded = round_tf(X)
586640
self.assertTrue(torch.equal(X, X_rounded))
@@ -590,27 +644,48 @@ def test_round_transform(self):
590644

591645
# test equals
592646
round_tf2 = Round(
593-
indices=int_idcs, approximate=approx, transform_on_train=False
647+
integer_indices=int_idcs,
648+
categorical_features=categorical_features,
649+
approximate=approx,
650+
transform_on_train=False,
594651
)
595652
self.assertTrue(round_tf.equals(round_tf2))
596653
# test different transform_on_train
597-
round_tf2 = Round(indices=int_idcs, approximate=approx)
654+
round_tf2 = Round(
655+
integer_indices=int_idcs,
656+
categorical_features=categorical_features,
657+
approximate=approx,
658+
)
598659
self.assertFalse(round_tf.equals(round_tf2))
599660
# test different approx
661+
round_tf = Round(
662+
integer_indices=int_idcs,
663+
)
600664
round_tf2 = Round(
601-
indices=int_idcs, approximate=not approx, transform_on_train=False
665+
integer_indices=int_idcs,
666+
approximate=not approx,
667+
transform_on_train=False,
602668
)
603669
self.assertFalse(round_tf.equals(round_tf2))
604670
# test different indices
671+
round_tf = Round(
672+
integer_indices=int_idcs,
673+
categorical_features=categorical_features,
674+
transform_on_train=False,
675+
)
605676
round_tf2 = Round(
606-
indices=[0, 1], approximate=approx, transform_on_train=False
677+
integer_indices=[0, 1],
678+
categorical_features=categorical_features,
679+
approximate=approx,
680+
transform_on_train=False,
607681
)
608682
self.assertFalse(round_tf.equals(round_tf2))
609683

610684
# test preprocess_transform
611685
round_tf.transform_on_train = False
612686
self.assertTrue(torch.equal(round_tf.preprocess_transform(X), X))
613687
round_tf.transform_on_train = True
688+
X_rounded = round_tf(X)
614689
self.assertTrue(
615690
torch.equal(round_tf.preprocess_transform(X), X_rounded)
616691
)

0 commit comments

Comments
 (0)