|
31 | 31 | from torch import nn, Tensor
|
32 | 32 | from torch.distributions import Kumaraswamy
|
33 | 33 | from torch.nn import Module, ModuleDict
|
| 34 | +from torch.nn.functional import one_hot |
34 | 35 |
|
35 | 36 |
|
36 | 37 | class InputTransform(ABC):
|
@@ -1358,3 +1359,133 @@ def _expanded_perturbations(self, X: Tensor) -> Tensor:
|
1358 | 1359 | else:
|
1359 | 1360 | p = p(X) if self.indices is None else p(X[..., self.indices])
|
1360 | 1361 | return p.transpose(-3, -2) # p is batch_shape x n_p x n x d
|
| 1362 | + |
| 1363 | + |
| 1364 | +class OneHotToNumeric(InputTransform, Module): |
| 1365 | + r"""Transform categorical parameters from a one-hot to a numeric representation. |
| 1366 | +
|
| 1367 | + This assumes that the categoricals are the trailing dimensions. |
| 1368 | + """ |
| 1369 | + |
| 1370 | + def __init__( |
| 1371 | + self, |
| 1372 | + dim: int, |
| 1373 | + categorical_features: Optional[Dict[int, int]] = None, |
| 1374 | + transform_on_train: bool = False, |
| 1375 | + transform_on_eval: bool = True, |
| 1376 | + transform_on_fantasize: bool = False, |
| 1377 | + ) -> None: |
| 1378 | + r"""Initialize. |
| 1379 | +
|
| 1380 | + Args: |
| 1381 | + dim: The dimension of the one-hot-encoded input. |
| 1382 | + categorical_features: A dictionary mapping the starting index of each |
| 1383 | + categorical feature to its cardinality. This assumes that categoricals |
| 1384 | + are one-hot encoded. |
| 1385 | + transform_on_train: A boolean indicating whether to apply the |
| 1386 | + transforms in train() mode. Default: False. |
| 1387 | + transform_on_eval: A boolean indicating whether to apply the |
| 1388 | + transform in eval() mode. Default: True. |
| 1389 | + transform_on_fantasize: A boolean indicating whether to apply the |
| 1390 | + transform when called from within a `fantasize` call. Default: False. |
| 1391 | +
|
| 1392 | + Returns: |
| 1393 | + A `batch_shape x n x d'`-dim tensor of where the one-hot encoded |
| 1394 | + categoricals are transformed to integer representation. |
| 1395 | + """ |
| 1396 | + super().__init__() |
| 1397 | + self.transform_on_train = transform_on_train |
| 1398 | + self.transform_on_eval = transform_on_eval |
| 1399 | + self.transform_on_fantasize = transform_on_fantasize |
| 1400 | + categorical_features = categorical_features or {} |
| 1401 | + # sort by starting index |
| 1402 | + self.categorical_features = OrderedDict( |
| 1403 | + sorted(categorical_features.items(), key=lambda x: x[0]) |
| 1404 | + ) |
| 1405 | + if len(self.categorical_features) > 0: |
| 1406 | + self.categorical_start_idx = min(self.categorical_features.keys()) |
| 1407 | + # check that the trailing dimensions are categoricals |
| 1408 | + end = self.categorical_start_idx |
| 1409 | + err_msg = ( |
| 1410 | + f"{self.__class__.__name__} requires that the categorical " |
| 1411 | + "parameters are the rightmost elements." |
| 1412 | + ) |
| 1413 | + for start, card in self.categorical_features.items(): |
| 1414 | + # the end of one one-hot representation should be followed |
| 1415 | + # by the start of the next |
| 1416 | + if end != start: |
| 1417 | + raise ValueError(err_msg) |
| 1418 | + end = start + card |
| 1419 | + if end != dim: |
| 1420 | + # check end |
| 1421 | + raise ValueError(err_msg) |
| 1422 | + # the numeric representation dimension is the total number of parameters |
| 1423 | + # (continuous, integer, and categorical) |
| 1424 | + self.numeric_dim = self.categorical_start_idx + len(categorical_features) |
| 1425 | + |
| 1426 | + def transform(self, X: Tensor) -> Tensor: |
| 1427 | + r"""Transform the categorical inputs into integer representation. |
| 1428 | +
|
| 1429 | + Args: |
| 1430 | + X: A `batch_shape x n x d`-dim tensor of inputs. |
| 1431 | +
|
| 1432 | + Returns: |
| 1433 | + A `batch_shape x n x d'`-dim tensor of where the one-hot encoded |
| 1434 | + categoricals are transformed to integer representation. |
| 1435 | + """ |
| 1436 | + if len(self.categorical_features) > 0: |
| 1437 | + X_numeric = X[..., : self.numeric_dim].clone() |
| 1438 | + idx = self.categorical_start_idx |
| 1439 | + for start, card in self.categorical_features.items(): |
| 1440 | + X_numeric[..., idx] = X[..., start : start + card].argmax(dim=-1) |
| 1441 | + idx += 1 |
| 1442 | + return X_numeric |
| 1443 | + return X |
| 1444 | + |
| 1445 | + def untransform(self, X: Tensor) -> Tensor: |
| 1446 | + r"""Transform the categoricals from integer representation to one-hot. |
| 1447 | +
|
| 1448 | + Args: |
| 1449 | + X: A `batch_shape x n x d'`-dim tensor of transformed inputs, where |
| 1450 | + the categoricals are represented as integers. |
| 1451 | +
|
| 1452 | + Returns: |
| 1453 | + A `batch_shape x n x d`-dim tensor of inputs, where the categoricals |
| 1454 | + have been transformed to one-hot representation. |
| 1455 | + """ |
| 1456 | + if len(self.categorical_features) > 0: |
| 1457 | + self.numeric_dim |
| 1458 | + one_hot_categoricals = [ |
| 1459 | + # note that self.categorical_features is sorted by the starting index |
| 1460 | + # in one-hot representation |
| 1461 | + one_hot( |
| 1462 | + X[..., idx - len(self.categorical_features)].long(), |
| 1463 | + num_classes=cardinality, |
| 1464 | + ) |
| 1465 | + for idx, cardinality in enumerate(self.categorical_features.values()) |
| 1466 | + ] |
| 1467 | + X = torch.cat( |
| 1468 | + [ |
| 1469 | + X[..., : self.categorical_start_idx], |
| 1470 | + *one_hot_categoricals, |
| 1471 | + ], |
| 1472 | + dim=-1, |
| 1473 | + ) |
| 1474 | + return X |
| 1475 | + |
| 1476 | + def equals(self, other: InputTransform) -> bool: |
| 1477 | + r"""Check if another input transform is equivalent. |
| 1478 | +
|
| 1479 | + Args: |
| 1480 | + other: Another input transform. |
| 1481 | +
|
| 1482 | + Returns: |
| 1483 | + A boolean indicating if the other transform is equivalent. |
| 1484 | + """ |
| 1485 | + return ( |
| 1486 | + type(self) == type(other) |
| 1487 | + and (self.transform_on_train == other.transform_on_train) |
| 1488 | + and (self.transform_on_eval == other.transform_on_eval) |
| 1489 | + and (self.transform_on_fantasize == other.transform_on_fantasize) |
| 1490 | + and self.categorical_features == other.categorical_features |
| 1491 | + ) |
0 commit comments