Skip to content

Commit 85ce7c7

Browse files
committed
Correctly specifying parameters and adding docs for ordinal likelihood
1 parent e0f8d60 commit 85ce7c7

File tree

3 files changed

+113
-33
lines changed

3 files changed

+113
-33
lines changed

docs/source/likelihoods.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ reduce the variance when computing approximate GP objective functions.
8080
.. autoclass:: StudentTLikelihood
8181
:members:
8282

83+
:hidden:`OrdinalLikelihood`
84+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
85+
86+
.. autoclass:: OrdinalLikelihood
87+
:members:
88+
8389

8490
Multi-Dimensional Likelihoods
8591
-----------------------------
Lines changed: 77 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,112 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Optional
22

33
import torch
44
from torch import Tensor
5+
from torch.distributions import Categorical
56

6-
from ..constraints import Positive
7+
from ..constraints import Interval, Positive
78
from ..distributions import MultivariateNormal
8-
from .likelihood import Likelihood
9+
from ..priors import Prior
10+
from .likelihood import _OneDimensionalLikelihood
11+
912

1013
def inv_probit(x: Tensor, jitter: float = 1e-3) -> Tensor:
    """
    Inverse probit function (standard normal CDF) with jitter for numerical stability.

    Args:
        x: Input tensor
        jitter: Small constant to ensure outputs are strictly between 0 and 1

    Returns:
        Probabilities between jitter and 1-jitter
    """
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). Use a Python float constant for
    # 1/sqrt(2) instead of building a fresh tensor and calling torch.sqrt on
    # every invocation (the original allocated torch.tensor(2.0) per call).
    cdf = 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))
    # Squash into [jitter, 1 - jitter] so downstream logs/divisions stay finite.
    return cdf * (1.0 - 2.0 * jitter) + jitter
2225

23-
class OrdinalLikelihood(_OneDimensionalLikelihood):
    r"""
    An ordinal likelihood for regressing over ordinal data.

    The data are integer values from :math:`0` to :math:`k`, and the user must specify :math:`(k-1)`
    'bin edges' which define the points at which the labels switch. Let the bin
    edges be :math:`[a_0, a_1, ... a_{k-1}]`, then the likelihood is

    .. math::
        p(Y=0|F) &= \Phi((a_0 - F) / \sigma)

        p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma)

        p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma)

        ...

        p(Y=K|F) &= 1 - \Phi((a_{k-1} - F) / \sigma)

    where :math:`\Phi` is the cumulative density function of a Gaussian (the inverse probit
    function) and :math:`\sigma` is a parameter to be learned.

    From Chu and Ghahramani, Journal of Machine Learning Research, 2005
    [https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf].

    :param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges.
    :param batch_shape: The batch shape of the learned sigma parameter (default: []).
    :param sigma_prior: Prior for sigma parameter :math:`\sigma`.
    :param sigma_constraint: Constraint for sigma parameter :math:`\sigma`.

    :ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-1}` bin edges
    :ivar torch.Tensor sigma: :math:`\sigma` parameter (scale)
    """

    def __init__(
        self,
        bin_edges: Tensor,
        batch_shape: torch.Size = torch.Size([]),
        sigma_prior: Optional[Prior] = None,
        sigma_constraint: Optional[Interval] = None,
    ) -> None:
        super().__init__()

        # k-1 edges partition the real line into k ordinal categories.
        self.num_bins = len(bin_edges) + 1
        # Bin edges are fixed hyperparameters, not learned (requires_grad=False).
        self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))

        if sigma_constraint is None:
            sigma_constraint = Positive()

        # FIX: register raw_sigma via register_parameter for consistency with
        # bin_edges above and the gpytorch raw-parameter/constraint convention
        # (the original assigned it as a plain attribute).
        self.register_parameter("raw_sigma", torch.nn.Parameter(torch.ones(*batch_shape, 1)))
        if sigma_prior is not None:
            self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v))

        self.register_constraint("raw_sigma", sigma_constraint)

    @property
    def sigma(self) -> Tensor:
        # Constrained (positive) view of the unconstrained raw parameter.
        return self.raw_sigma_constraint.transform(self.raw_sigma)

    @sigma.setter
    def sigma(self, value: Tensor) -> None:
        self._set_sigma(value)

    def _set_sigma(self, value: Tensor) -> None:
        # Accept plain Python scalars/sequences; match raw_sigma's dtype/device.
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_sigma)
        # Store the inverse-transformed value so `sigma` round-trips exactly.
        self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value))

    def forward(self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any) -> Categorical:
        """Return p(Y | F) as a Categorical distribution over the k ordinal labels.

        NOTE(review): the mutable default `data={}` is kept for signature
        compatibility; it is safe only because `data` is never read or mutated here.
        """
        # Allow callers to pass the latent GP distribution directly; sample from it.
        if isinstance(function_samples, MultivariateNormal):
            function_samples = function_samples.sample()

        # Pad the scaled edges with +/- inf so each bin i has both boundaries:
        # P(Y=i|F) = Phi(a_i/sigma - F/sigma) - Phi(a_{i-1}/sigma - F/sigma).
        scaled_edges = self.bin_edges / self.sigma
        scaled_edges_left = torch.cat([scaled_edges, torch.tensor([torch.inf], device=scaled_edges.device)], dim=-1)
        scaled_edges_right = torch.cat([torch.tensor([-torch.inf], device=scaled_edges.device), scaled_edges])

        # Broadcast: samples gain a trailing bin dimension against the edge row vectors.
        function_samples = function_samples.unsqueeze(-1)
        scaled_edges_left = scaled_edges_left.reshape(1, -1)
        scaled_edges_right = scaled_edges_right.reshape(1, -1)
        probs = inv_probit(scaled_edges_left - function_samples / self.sigma) - inv_probit(
            scaled_edges_right - function_samples / self.sigma
        )

        return Categorical(probs=probs)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
5+
import torch
6+
from torch.distributions import Distribution
7+
8+
from gpytorch.likelihoods import OrdinalLikelihood
9+
from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase
10+
11+
12+
class TestOrdinalLikelihood(BaseLikelihoodTestCase, unittest.TestCase):
    """Tests for OrdinalLikelihood using the standard likelihood test harness."""

    seed = 0

    def create_likelihood(self):
        # Two bin edges -> three ordinal categories {0, 1, 2}.
        bin_edges = torch.tensor([-0.5, 0.5])
        return OrdinalLikelihood(bin_edges)

    def _create_targets(self, batch_shape=torch.Size([])):
        # Uniform ordinal labels; five targets per batch entry.
        return torch.distributions.Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3])).sample(
            torch.Size([*batch_shape, 5])
        )

    def _test_marginal(self, batch_shape):
        likelihood = self.create_likelihood()
        # FIX: renamed `input` (shadowed the builtin) to `marginal_input`.
        marginal_input = self._create_marginal_input(batch_shape)
        output = likelihood(marginal_input)

        # FIX: assertIsInstance gives a clearer failure message than
        # assertTrue(isinstance(...)).
        self.assertIsInstance(output, Distribution)
        # Samples must end with (*batch_shape, 5) to match the targets.
        self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5]))

0 commit comments

Comments
 (0)