6 changes: 6 additions & 0 deletions docs/source/likelihoods.rst
@@ -80,6 +80,12 @@ reduce the variance when computing approximate GP objective functions.
.. autoclass:: StudentTLikelihood
:members:

:hidden:`OrdinalLikelihood`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: OrdinalLikelihood
:members:


Multi-Dimensional Likelihoods
-----------------------------
2 changes: 2 additions & 0 deletions gpytorch/likelihoods/__init__.py
@@ -14,6 +14,7 @@
from .likelihood_list import LikelihoodList
from .multitask_gaussian_likelihood import _MultitaskGaussianLikelihoodBase, MultitaskGaussianLikelihood
from .noise_models import HeteroskedasticNoise
from .ordinal_likelihood import OrdinalLikelihood
from .softmax_likelihood import SoftmaxLikelihood
from .student_t_likelihood import StudentTLikelihood

@@ -32,6 +33,7 @@
"Likelihood",
"LikelihoodList",
"MultitaskGaussianLikelihood",
"OrdinalLikelihood",
"SoftmaxLikelihood",
"StudentTLikelihood",
]
108 changes: 108 additions & 0 deletions gpytorch/likelihoods/ordinal_likelihood.py
@@ -0,0 +1,108 @@
import math
from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torch.distributions import Categorical

from ..constraints import Interval, Positive
from ..priors import Prior
from .likelihood import _OneDimensionalLikelihood


def inv_probit(x: Tensor, jitter: float = 1e-3) -> Tensor:
"""
Inverse probit function (standard normal CDF) with jitter for numerical stability.

Args:
x: Input tensor
jitter: Small constant to ensure outputs are strictly between 0 and 1

Returns:
Probabilities between jitter and 1-jitter
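
    Example (a quick sanity check; the jitter terms cancel exactly at zero)::

        >>> inv_probit(torch.tensor(0.0))
        tensor(0.5000)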
"""
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) * (1 - 2 * jitter) + jitter


class OrdinalLikelihood(_OneDimensionalLikelihood):
r"""
An ordinal likelihood for regressing over ordinal data.

    The data are integer labels from :math:`0` to :math:`k - 1`, and the user must specify the
    :math:`k - 1` 'bin edges' at which the labels switch. Let the bin
    edges be :math:`[a_0, a_1, \ldots, a_{k-2}]`; the likelihood is then

.. math::
p(Y=0|F) &= \Phi((a_0 - F) / \sigma)

p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma)

p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma)

...

        p(Y=k-1|F) &= 1 - \Phi((a_{k-2} - F) / \sigma)

    where :math:`\Phi` is the cumulative distribution function of a standard Gaussian (the
    inverse of the probit function) and :math:`\sigma` is a learned scale parameter.
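
    For example, with bin edges :math:`a = (-0.5, 0.5)`, :math:`\sigma = 1`, and :math:`f = 0`,
    the three class probabilities are
    :math:`(\Phi(-0.5),\ \Phi(0.5) - \Phi(-0.5),\ 1 - \Phi(0.5)) \approx (0.31, 0.38, 0.31)`.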

    From Chu and Ghahramani, "Gaussian Processes for Ordinal Regression",
    Journal of Machine Learning Research, 2005
    [https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf].

:param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges.
:param batch_shape: The batch shape of the learned sigma parameter (default: []).
:param sigma_prior: Prior for sigma parameter :math:`\sigma`.
:param sigma_constraint: Constraint for sigma parameter :math:`\sigma`.

    :ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-2}` bin edges
:ivar torch.Tensor sigma: :math:`\sigma` parameter (scale)
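
    Example (an illustrative sketch; the two arbitrary bin edges give three classes):
        >>> bin_edges = torch.tensor([-0.5, 0.5])
        >>> likelihood = OrdinalLikelihood(bin_edges)
        >>> f_samples = torch.randn(10, 5)  # 10 samples of f at 5 inputs
        >>> p_y_given_f = likelihood(f_samples)  # Categorical over labels {0, 1, 2}
        >>> p_y_given_f.probs.shape
        torch.Size([10, 5, 3])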
"""

def __init__(
self,
bin_edges: Tensor,
batch_shape: torch.Size = torch.Size([]),
sigma_prior: Optional[Prior] = None,
sigma_constraint: Optional[Interval] = None,
) -> None:
super().__init__()

        # k - 1 user-specified bin edges define k ordinal classes
        self.num_bins = len(bin_edges) + 1
        self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))

        if sigma_constraint is None:
            sigma_constraint = Positive()

        self.register_parameter("raw_sigma", torch.nn.Parameter(torch.ones(*batch_shape, 1)))
        if sigma_prior is not None:
            self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v))

        self.register_constraint("raw_sigma", sigma_constraint)

@property
def sigma(self) -> Tensor:
return self.raw_sigma_constraint.transform(self.raw_sigma)

@sigma.setter
def sigma(self, value: Tensor) -> None:
self._set_sigma(value)

def _set_sigma(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_sigma)
self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value))

    def forward(self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any) -> Categorical:
        # Scale the bin edges by sigma and pad with +/- infinity so that the first
        # and last bins are half-open (Phi(-inf) -> 0 and Phi(+inf) -> 1, up to jitter)
        scaled_edges = self.bin_edges / self.sigma
        inf = torch.tensor([torch.inf], device=scaled_edges.device, dtype=scaled_edges.dtype)
        scaled_edges_left = torch.cat([scaled_edges, inf], dim=-1)
        scaled_edges_right = torch.cat([-inf, scaled_edges], dim=-1)

        # p(Y=i|f) = Phi((a_i - f) / sigma) - Phi((a_{i-1} - f) / sigma), for all bins at once.
        # Unsqueeze so samples (..., n, 1), edges (*batch, 1, k), and sigma (*batch, 1, 1)
        # broadcast to a (..., n, k) probability tensor
        function_samples = function_samples.unsqueeze(-1)
        scaled_edges_left = scaled_edges_left.unsqueeze(-2)
        scaled_edges_right = scaled_edges_right.unsqueeze(-2)
        sigma = self.sigma.unsqueeze(-2)
        probs = inv_probit(scaled_edges_left - function_samples / sigma) - inv_probit(
            scaled_edges_right - function_samples / sigma
        )

        return Categorical(probs=probs)
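
A minimal end-to-end sketch of how the new likelihood plugs into a standard GPyTorch variational model. The toy data, model class, and training loop below are illustrative assumptions, not part of this diff; the APIs used (ApproximateGP, VariationalStrategy, VariationalELBO) are existing GPyTorch ones.

import torch
import gpytorch
from gpytorch.likelihoods import OrdinalLikelihood

# Toy ordinal data: threshold a smooth function into labels {0, 1, 2}
train_x = torch.linspace(-3, 3, 100)
train_y = torch.bucketize(torch.sin(train_x), torch.tensor([-0.5, 0.5]))

class OrdinalGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

model = OrdinalGPModel(inducing_points=train_x[::10].clone())
likelihood = OrdinalLikelihood(bin_edges=torch.tensor([-0.5, 0.5]))
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())
optimizer = torch.optim.Adam([*model.parameters(), *likelihood.parameters()], lr=0.1)

for _ in range(200):
    optimizer.zero_grad()
    loss = -mll(model(train_x), train_y)
    loss.backward()
    optimizer.step()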
30 changes: 30 additions & 0 deletions test/likelihoods/test_ordinal_likelihood.py
@@ -0,0 +1,30 @@
#!/usr/bin/env python3

import unittest

import torch
from torch.distributions import Distribution

from gpytorch.likelihoods import OrdinalLikelihood
from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase


class TestOrdinalLikelihood(BaseLikelihoodTestCase, unittest.TestCase):
seed = 0

def create_likelihood(self):
bin_edges = torch.tensor([-0.5, 0.5])
return OrdinalLikelihood(bin_edges)

def _create_targets(self, batch_shape=torch.Size([])):
return torch.distributions.Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3])).sample(
torch.Size([*batch_shape, 5])
)

    # Override: the marginal here is a Categorical over labels rather than a Gaussian
    def _test_marginal(self, batch_shape):
likelihood = self.create_likelihood()
input = self._create_marginal_input(batch_shape)
output = likelihood(input)

self.assertTrue(isinstance(output, Distribution))
self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5]))