diff --git a/docs/source/likelihoods.rst b/docs/source/likelihoods.rst index 3b71b42ae..e2b343cf0 100644 --- a/docs/source/likelihoods.rst +++ b/docs/source/likelihoods.rst @@ -80,6 +80,12 @@ reduce the variance when computing approximate GP objective functions. .. autoclass:: StudentTLikelihood :members: +:hidden:`OrdinalLikelihood` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: OrdinalLikelihood + :members: + Multi-Dimensional Likelihoods ----------------------------- diff --git a/gpytorch/likelihoods/__init__.py b/gpytorch/likelihoods/__init__.py index 31a370079..93bb0da70 100644 --- a/gpytorch/likelihoods/__init__.py +++ b/gpytorch/likelihoods/__init__.py @@ -14,6 +14,7 @@ from .likelihood_list import LikelihoodList from .multitask_gaussian_likelihood import _MultitaskGaussianLikelihoodBase, MultitaskGaussianLikelihood from .noise_models import HeteroskedasticNoise +from .ordinal_likelihood import OrdinalLikelihood from .softmax_likelihood import SoftmaxLikelihood from .student_t_likelihood import StudentTLikelihood @@ -32,6 +33,7 @@ "Likelihood", "LikelihoodList", "MultitaskGaussianLikelihood", + "OrdinalLikelihood", "SoftmaxLikelihood", "StudentTLikelihood", ] diff --git a/gpytorch/likelihoods/ordinal_likelihood.py b/gpytorch/likelihoods/ordinal_likelihood.py new file mode 100644 index 000000000..4946345e8 --- /dev/null +++ b/gpytorch/likelihoods/ordinal_likelihood.py @@ -0,0 +1,108 @@ +from typing import Any, Dict, Optional + +import torch +from torch import Tensor +from torch.distributions import Categorical + +from ..constraints import Interval, Positive +from ..priors import Prior +from .likelihood import _OneDimensionalLikelihood + + +def inv_probit(x, jitter=1e-3): + """ + Inverse probit function (standard normal CDF) with jitter for numerical stability. + + Args: + x: Input tensor + jitter: Small constant to ensure outputs are strictly between 0 and 1 + + Returns: + Probabilities between jitter and 1-jitter + """ + return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter + + +class OrdinalLikelihood(_OneDimensionalLikelihood): + r""" + An ordinal likelihood for regressing over ordinal data. + + The data are integer values from :math:`0` to :math:`k`, and the user must specify :math:`(k-1)` + 'bin edges' which define the points at which the labels switch. Let the bin + edges be :math:`[a_0, a_1, ... a_{k-1}]`, then the likelihood is + + .. math:: + p(Y=0|F) &= \Phi((a_0 - F) / \sigma) + + p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma) + + p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma) + + ... + + p(Y=K|F) &= 1 - \Phi((a_{k-1} - F) / \sigma) + + where :math:`\Phi` is the cumulative density function of a Gaussian (the inverse probit + function) and :math:`\sigma` is a parameter to be learned. + + From Chu et Ghahramani, Journal of Machine Learning Research, 2005 + [https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf]. + + :param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges. + :param batch_shape: The batch shape of the learned sigma parameter (default: []). + :param sigma_prior: Prior for sigma parameter :math:`\sigma`. + :param sigma_constraint: Constraint for sigma parameter :math:`\sigma`. + + :ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-1}` bin edges + :ivar torch.Tensor sigma: :math:`\sigma` parameter (scale) + """ + + def __init__( + self, + bin_edges: Tensor, + batch_shape: torch.Size = torch.Size([]), + sigma_prior: Optional[Prior] = None, + sigma_constraint: Optional[Interval] = None, + ) -> None: + super().__init__() + + self.num_bins = len(bin_edges) + 1 + self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False)) + + if sigma_constraint is None: + sigma_constraint = Positive() + + self.raw_sigma = torch.nn.Parameter(torch.ones(*batch_shape, 1)) + if sigma_prior is not None: + self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v)) + + self.register_constraint("raw_sigma", sigma_constraint) + + @property + def sigma(self) -> Tensor: + return self.raw_sigma_constraint.transform(self.raw_sigma) + + @sigma.setter + def sigma(self, value: Tensor) -> None: + self._set_sigma(value) + + def _set_sigma(self, value: Tensor) -> None: + if not torch.is_tensor(value): + value = torch.as_tensor(value).to(self.raw_sigma) + self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value)) + + def forward(self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any) -> Categorical: + # Compute scaled bin edges + scaled_edges = self.bin_edges / self.sigma + scaled_edges_left = torch.cat([scaled_edges, torch.tensor([torch.inf], device=scaled_edges.device)], dim=-1) + scaled_edges_right = torch.cat([torch.tensor([-torch.inf], device=scaled_edges.device), scaled_edges]) + + # Calculate cumulative probabilities using standard normal CDF (probit function) + function_samples = function_samples.unsqueeze(-1) + scaled_edges_left = scaled_edges_left.reshape(1, -1) + scaled_edges_right = scaled_edges_right.reshape(1, -1) + probs = inv_probit(scaled_edges_left - function_samples / self.sigma) - inv_probit( + scaled_edges_right - function_samples / self.sigma + ) + + return Categorical(probs=probs) diff --git a/test/likelihoods/test_ordinal_likelihood.py b/test/likelihoods/test_ordinal_likelihood.py new file mode 100644 index 000000000..3cd2b8b4c --- /dev/null +++ b/test/likelihoods/test_ordinal_likelihood.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from torch.distributions import Distribution + +from gpytorch.likelihoods import OrdinalLikelihood +from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase + + +class TestOrdinalLikelihood(BaseLikelihoodTestCase, unittest.TestCase): + seed = 0 + + def create_likelihood(self): + bin_edges = torch.tensor([-0.5, 0.5]) + return OrdinalLikelihood(bin_edges) + + def _create_targets(self, batch_shape=torch.Size([])): + return torch.distributions.Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3])).sample( + torch.Size([*batch_shape, 5]) + ) + + def _test_marginal(self, batch_shape): + likelihood = self.create_likelihood() + input = self._create_marginal_input(batch_shape) + output = likelihood(input) + + self.assertTrue(isinstance(output, Distribution)) + self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5]))