|
1 | | -from typing import Any, Dict |
| 1 | +from typing import Any, Dict, Optional |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | from torch import Tensor |
| 5 | +from torch.distributions import Categorical |
5 | 6 |
|
6 | | -from ..constraints import Positive |
| 7 | +from ..constraints import Interval, Positive |
7 | 8 | from ..distributions import MultivariateNormal |
8 | | -from .likelihood import Likelihood |
| 9 | +from ..priors import Prior |
| 10 | +from .likelihood import _OneDimensionalLikelihood |
| 11 | + |
9 | 12 |
|
def inv_probit(x, jitter=1e-3):
    """
    Inverse probit function (standard normal CDF) with jitter for numerical stability.

    Args:
        x: Input tensor
        jitter: Small constant to ensure outputs are strictly between 0 and 1

    Returns:
        Probabilities between jitter and 1-jitter
    """
    # torch.special.ndtr(x) == 0.5 * (1 + erf(x / sqrt(2))), i.e. the standard normal CDF.
    # Using it avoids allocating a new 0-dim sqrt(2) tensor on every call and avoids any
    # device/dtype mismatch between that constant and the input tensor.
    # The affine map then squeezes the output into [jitter, 1 - jitter].
    return torch.special.ndtr(x) * (1 - 2 * jitter) + jitter
22 | 25 |
|
class OrdinalLikelihood(_OneDimensionalLikelihood):
    r"""
    An ordinal likelihood for regressing over ordinal data.

    The data are integer values from :math:`0` to :math:`k`, and the user must specify :math:`(k-1)`
    'bin edges' which define the points at which the labels switch. Let the bin
    edges be :math:`[a_0, a_1, ... a_{k-1}]`, then the likelihood is

    .. math::
        p(Y=0|F) &= \Phi((a_0 - F) / \sigma)

        p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma)

        p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma)

        ...

        p(Y=K|F) &= 1 - \Phi((a_{k-1} - F) / \sigma)

    where :math:`\Phi` is the cumulative density function of a Gaussian (the inverse probit
    function) and :math:`\sigma` is a parameter to be learned.

    From Chu et Ghahramani, Journal of Machine Learning Research, 2005
    [https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf].

    :param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges.
    :param batch_shape: The batch shape of the learned sigma parameter (default: []).
    :param sigma_prior: Prior for sigma parameter :math:`\sigma`.
    :param sigma_constraint: Constraint for sigma parameter :math:`\sigma`.

    :ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-1}` bin edges
    :ivar torch.Tensor sigma: :math:`\sigma` parameter (scale)
    """

    def __init__(
        self,
        bin_edges: Tensor,
        batch_shape: torch.Size = torch.Size([]),
        sigma_prior: Optional[Prior] = None,
        sigma_constraint: Optional[Interval] = None,
    ) -> None:
        super().__init__()

        # (k-1) edges delimit k ordinal bins.
        self.num_bins = len(bin_edges) + 1
        # Edges are fixed by the user, not learned (requires_grad=False).
        self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))

        if sigma_constraint is None:
            sigma_constraint = Positive()

        # sigma is stored unconstrained as raw_sigma; the constraint transform maps it to a
        # positive scale. Trailing singleton dim lets it broadcast over the edge dimension.
        self.raw_sigma = torch.nn.Parameter(torch.ones(*batch_shape, 1))
        if sigma_prior is not None:
            self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v))

        self.register_constraint("raw_sigma", sigma_constraint)

    @property
    def sigma(self) -> Tensor:
        # Constrained (positive) view of the raw parameter.
        return self.raw_sigma_constraint.transform(self.raw_sigma)

    @sigma.setter
    def sigma(self, value: Tensor) -> None:
        self._set_sigma(value)

    def _set_sigma(self, value: Tensor) -> None:
        # Accept plain Python scalars/sequences as well as tensors.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_sigma)
        # Store the inverse-transformed value so that self.sigma round-trips to `value`.
        self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value))

    def forward(self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any) -> Categorical:
        """
        Compute the ordinal label distribution for the given latent function samples.

        :param function_samples: Samples of the latent function F (a MultivariateNormal is
            also accepted and will be sampled from once).
        :param data: Unused; kept for API compatibility with other likelihoods.
        :return: A Categorical distribution over the k ordinal labels, with a trailing
            category dimension of size ``num_bins``.
        """
        if isinstance(function_samples, MultivariateNormal):
            function_samples = function_samples.sample()

        # Scale the edges by sigma so all thresholds live in probit units.
        scaled_edges = self.bin_edges / self.sigma

        # Pad with +/- infinity so every bin has an upper and a lower threshold:
        # upper edges are [a_0, ..., a_{k-1}, +inf], lower edges are [-inf, a_0, ..., a_{k-1}].
        # Match dtype/device and batch shape of the scaled edges so torch.cat is valid even
        # for non-default dtypes or a non-empty batch_shape.
        pad_shape = scaled_edges.shape[:-1] + (1,)
        pos_inf = torch.full(pad_shape, torch.inf, dtype=scaled_edges.dtype, device=scaled_edges.device)
        upper_edges = torch.cat([scaled_edges, pos_inf], dim=-1)
        lower_edges = torch.cat([-pos_inf, scaled_edges], dim=-1)

        # P(Y = j | F) = Phi((a_j - F)/sigma) - Phi((a_{j-1} - F)/sigma)
        function_samples = function_samples.unsqueeze(-1)
        # unsqueeze(-2) (rather than reshape(1, -1)) keeps any leading batch dimensions of
        # the edges intact while still broadcasting against the samples' data dimension.
        upper_edges = upper_edges.unsqueeze(-2)
        lower_edges = lower_edges.unsqueeze(-2)
        probs = inv_probit(upper_edges - function_samples / self.sigma) - inv_probit(
            lower_edges - function_samples / self.sigma
        )

        return Categorical(probs=probs)
0 commit comments