Commit 3d5f32d

sdaulton authored and facebook-github-bot committed
add one hot to numeric input transform (#1517)
Summary: Pull Request resolved: #1517 (see title).

Differential Revision: https://internalfb.com/D41482322
fbshipit-source-id: db4d684f6cbbe8af5ae6b257a949da0610871bd8
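For context (not part of the diff): the transform collapses each trailing one-hot block to a single integer column via argmax and leaves the leading continuous columns untouched. A minimal sketch of that mapping in plain PyTorch, with made-up values for illustration:

import torch

# one point: two continuous features followed by a 3-way categorical,
# one-hot encoded in the last three columns
x = torch.tensor([[0.25, 0.70, 0.0, 0.0, 1.0]])
# the one-hot block [0, 0, 1] collapses to its argmax, i.e. category 2
x_numeric = torch.cat(
    [x[..., :2], x[..., 2:5].argmax(dim=-1, keepdim=True).to(x)], dim=-1
)
print(x_numeric)  # tensor([[0.2500, 0.7000, 2.0000]])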
1 parent 0fb00ef commit 3d5f32d

File tree: 2 files changed, +211 -0 lines changed

botorch/models/transforms/input.py

Lines changed: 131 additions & 0 deletions

@@ -31,6 +31,7 @@
 from torch import nn, Tensor
 from torch.distributions import Kumaraswamy
 from torch.nn import Module, ModuleDict
+from torch.nn.functional import one_hot


 class InputTransform(ABC):
@@ -1358,3 +1359,133 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
         else:
             p = p(X) if self.indices is None else p(X[..., self.indices])
         return p.transpose(-3, -2)  # p is batch_shape x n_p x n x d
+
+
+class OneHotToNumeric(InputTransform, Module):
+    r"""Transform categorical parameters from a one-hot to a numeric representation.
+
+    This assumes that the categoricals are the trailing dimensions.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        categorical_features: Optional[Dict[int, int]] = None,
+        transform_on_train: bool = False,
+        transform_on_eval: bool = True,
+        transform_on_fantasize: bool = False,
+    ) -> None:
+        r"""Initialize.
+
+        Args:
+            dim: The dimension of the one-hot-encoded input.
+            categorical_features: A dictionary mapping the starting index of each
+                categorical feature to its cardinality. This assumes that the
+                categoricals are one-hot encoded.
+            transform_on_train: A boolean indicating whether to apply the
+                transform in train() mode. Default: False.
+            transform_on_eval: A boolean indicating whether to apply the
+                transform in eval() mode. Default: True.
+            transform_on_fantasize: A boolean indicating whether to apply the
+                transform when called from within a `fantasize` call.
+                Default: False.
+        """
+        super().__init__()
+        self.transform_on_train = transform_on_train
+        self.transform_on_eval = transform_on_eval
+        self.transform_on_fantasize = transform_on_fantasize
+        categorical_features = categorical_features or {}
+        # sort by starting index
+        self.categorical_features = OrderedDict(
+            sorted(categorical_features.items(), key=lambda x: x[0])
+        )
+        if len(self.categorical_features) > 0:
+            self.categorical_start_idx = min(self.categorical_features.keys())
+            # check that the trailing dimensions are categoricals
+            end = self.categorical_start_idx
+            err_msg = (
+                f"{self.__class__.__name__} requires that the categorical "
+                "parameters are the rightmost elements."
+            )
+            for start, card in self.categorical_features.items():
+                # the end of one one-hot representation should be followed
+                # by the start of the next
+                if end != start:
+                    raise ValueError(err_msg)
+                end = start + card
+            # check that the last one-hot block ends at the last dimension
+            if end != dim:
+                raise ValueError(err_msg)
+            # the numeric representation dimension is the total number of
+            # parameters (continuous, integer, and categorical)
+            self.numeric_dim = self.categorical_start_idx + len(categorical_features)
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the categorical inputs into integer representation.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d'`-dim tensor where the one-hot encoded
+            categoricals are transformed to integer representation.
+        """
+        if len(self.categorical_features) > 0:
+            X_numeric = X[..., : self.numeric_dim].clone()
+            idx = self.categorical_start_idx
+            for start, card in self.categorical_features.items():
+                X_numeric[..., idx] = X[..., start : start + card].argmax(dim=-1)
+                idx += 1
+            return X_numeric
+        return X
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Transform the categoricals from integer representation to one-hot.
+
+        Args:
+            X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where
+                the categoricals are represented as integers.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of inputs, where the categoricals
+            have been transformed to one-hot representation.
+        """
+        if len(self.categorical_features) > 0:
+            one_hot_categoricals = [
+                # note that self.categorical_features is sorted by the starting
+                # index in one-hot representation
+                one_hot(
+                    X[..., idx - len(self.categorical_features)].long(),
+                    num_classes=cardinality,
+                )
+                for idx, cardinality in enumerate(self.categorical_features.values())
+            ]
+            X = torch.cat(
+                [
+                    X[..., : self.categorical_start_idx],
+                    *one_hot_categoricals,
+                ],
+                dim=-1,
+            )
+        return X
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+        """
+        return (
+            type(self) == type(other)
+            and (self.transform_on_train == other.transform_on_train)
+            and (self.transform_on_eval == other.transform_on_eval)
+            and (self.transform_on_fantasize == other.transform_on_fantasize)
+            and self.categorical_features == other.categorical_features
+        )
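To see the new class end to end, here is a hedged usage sketch mirroring the unit test below: an 8-dimensional input with three continuous features followed by a 3-way and a 2-way categorical (shapes and values here are illustrative, not prescribed by the API):

import torch
from torch.nn.functional import one_hot
from botorch.models.transforms.input import OneHotToNumeric

# three continuous features, then a 3-way and a 2-way one-hot categorical
cont = torch.rand(4, 3)
cat1 = one_hot(torch.randint(0, 3, (4,)), num_classes=3).to(cont)
cat2 = one_hot(torch.randint(0, 2, (4,)), num_classes=2).to(cont)
X = torch.cat([cont, cat1, cat2], dim=-1)  # shape: 4 x 8

# map starting index in the one-hot layout to cardinality
tf = OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2})
tf.eval()  # by default the transform is applied in eval() mode only

X_numeric = tf(X)  # shape: 4 x 5 (cont plus one integer column per categorical)
X_round_trip = tf.untransform(X_numeric)
assert torch.equal(X_round_trip, X)

The round trip is exact because argmax over a valid one-hot block recovers the category index, and one_hot re-expands it.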

test/models/transforms/test_input.py

Lines changed: 80 additions & 0 deletions

@@ -21,6 +21,7 @@
     InputTransform,
     Log10,
     Normalize,
+    OneHotToNumeric,
     Round,
     Warp,
 )
@@ -889,6 +890,85 @@ def test_warp_transform(self):
             warp_tf._set_concentration(i=1, value=3.0)
             self.assertTrue((warp_tf.concentration1 == 3.0).all())

+    def test_one_hot_to_numeric(self):
+        dim = 8
+        # test exception when categoricals are not the trailing dimensions
+        categorical_features = {0: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        # categoricals at start and end of X but not in between
+        categorical_features = {0: 3, 6: 2}
+        with self.assertRaises(ValueError):
+            OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+        for dtype in (torch.float, torch.double):
+            categorical_features = {6: 2, 3: 3}
+            tf = OneHotToNumeric(dim=dim, categorical_features=categorical_features)
+            tf.eval()
+            self.assertEqual(tf.categorical_features, {3: 3, 6: 2})
+            cat1_numeric = torch.randint(0, 3, (3,), device=self.device)
+            cat1 = one_hot(cat1_numeric, num_classes=3)
+            cat2_numeric = torch.randint(0, 2, (3,), device=self.device)
+            cat2 = one_hot(cat2_numeric, num_classes=2)
+            cont = torch.rand(3, 3, dtype=dtype, device=self.device)
+            X = torch.cat([cont, cat1, cat2], dim=-1)
+            # test forward
+            X_numeric = tf(X)
+            expected = torch.cat(
+                [
+                    cont,
+                    cat1_numeric.view(-1, 1).to(cont),
+                    cat2_numeric.view(-1, 1).to(cont),
+                ],
+                dim=-1,
+            )
+            self.assertTrue(torch.equal(X_numeric, expected))
+
+            # test untransform
+            X2 = tf.untransform(X_numeric)
+            self.assertTrue(torch.equal(X2, X))
+
+            # test no categorical features
+            tf = OneHotToNumeric(dim=dim, categorical_features={})
+            tf.eval()
+            X_tf = tf(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            X2 = tf(X_tf)
+            self.assertTrue(torch.equal(X2, X_tf))
+
+            # test no transform on eval
+            tf2 = OneHotToNumeric(
+                dim=dim,
+                categorical_features=categorical_features,
+                transform_on_eval=False,
+            )
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+
+            # test no transform on train
+            tf2 = OneHotToNumeric(
+                dim=dim,
+                categorical_features=categorical_features,
+                transform_on_train=False,
+            )
+            X_tf = tf2(X)
+            self.assertTrue(torch.equal(X, X_tf))
+            tf2.eval()
+            X_tf = tf2(X)
+            self.assertFalse(torch.equal(X, X_tf))
+
+            # test equals
+            tf3 = OneHotToNumeric(
+                dim=dim,
+                categorical_features=categorical_features,
+                transform_on_train=False,
+            )
+            self.assertTrue(tf3.equals(tf2))
+            # test different transform_on_train
+            tf3 = OneHotToNumeric(
+                dim=dim,
+                categorical_features=categorical_features,
+                transform_on_train=True,
+            )
+            self.assertFalse(tf3.equals(tf2))
+            # test different categorical features
+            tf3 = OneHotToNumeric(
+                dim=dim, categorical_features={}, transform_on_train=False
+            )
+            self.assertFalse(tf3.equals(tf2))


 class TestAppendFeatures(BotorchTestCase):
     def test_append_features(self):
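Because the indexing in `transform` uses `...`, the same code path handles inputs with leading batch dimensions, matching the `batch_shape x n x d` shapes the docstrings advertise. A quick hedged sketch (batch shape chosen arbitrarily):

import torch
from torch.nn.functional import one_hot
from botorch.models.transforms.input import OneHotToNumeric

tf = OneHotToNumeric(dim=8, categorical_features={3: 3, 6: 2})
tf.eval()

# a batch of 2 candidate sets, each with 4 points in the 8-dim one-hot space
X = torch.rand(2, 4, 8)
X[..., 3:6] = one_hot(torch.randint(0, 3, (2, 4)), num_classes=3).to(X)
X[..., 6:8] = one_hot(torch.randint(0, 2, (2, 4)), num_classes=2).to(X)

assert tf(X).shape == (2, 4, 5)  # batch dimensions pass through unchanged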
