-
Notifications
You must be signed in to change notification settings - Fork 575
Adding ordinal likelihood #2639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bendavidsteel
wants to merge
3
commits into
cornellius-gp:main
Choose a base branch
from
bendavidsteel:ordinal-likelihood
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
| from torch.distributions import Categorical | ||
|
|
||
| from ..constraints import Interval, Positive | ||
| from ..distributions import MultivariateNormal | ||
| from ..priors import Prior | ||
| from .likelihood import _OneDimensionalLikelihood | ||
|
|
||
|
|
||
| def inv_probit(x, jitter=1e-3): | ||
| """ | ||
| Inverse probit function (standard normal CDF) with jitter for numerical stability. | ||
|
|
||
| Args: | ||
| x: Input tensor | ||
| jitter: Small constant to ensure outputs are strictly between 0 and 1 | ||
|
|
||
| Returns: | ||
| Probabilities between jitter and 1-jitter | ||
| """ | ||
| return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter | ||
|
|
||
|
|
||
| class OrdinalLikelihood(_OneDimensionalLikelihood): | ||
| r""" | ||
| An ordinal likelihood for regressing over ordinal data. | ||
|
|
||
| The data are integer values from :math:`0` to :math:`k`, and the user must specify :math:`(k-1)` | ||
| 'bin edges' which define the points at which the labels switch. Let the bin | ||
| edges be :math:`[a_0, a_1, ... a_{k-1}]`, then the likelihood is | ||
|
|
||
| .. math:: | ||
| p(Y=0|F) &= \Phi((a_0 - F) / \sigma) | ||
|
|
||
| p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma) | ||
|
|
||
| p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma) | ||
|
|
||
| ... | ||
|
|
||
| p(Y=K|F) &= 1 - \Phi((a_{k-1} - F) / \sigma) | ||
|
|
||
| where :math:`\Phi` is the cumulative density function of a Gaussian (the inverse probit | ||
| function) and :math:`\sigma` is a parameter to be learned. | ||
|
|
||
| From Chu et Ghahramani, Journal of Machine Learning Research, 2005 | ||
| [https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf]. | ||
|
|
||
| :param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges. | ||
| :param batch_shape: The batch shape of the learned sigma parameter (default: []). | ||
| :param sigma_prior: Prior for sigma parameter :math:`\sigma`. | ||
| :param sigma_constraint: Constraint for sigma parameter :math:`\sigma`. | ||
|
|
||
| :ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-1}` bin edges | ||
| :ivar torch.Tensor sigma: :math:`\sigma` parameter (scale) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| bin_edges: Tensor, | ||
| batch_shape: torch.Size = torch.Size([]), | ||
| sigma_prior: Optional[Prior] = None, | ||
| sigma_constraint: Optional[Interval] = None, | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| self.num_bins = len(bin_edges) + 1 | ||
| self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False)) | ||
|
|
||
| if sigma_constraint is None: | ||
| sigma_constraint = Positive() | ||
|
|
||
| self.raw_sigma = torch.nn.Parameter(torch.ones(*batch_shape, 1)) | ||
| if sigma_prior is not None: | ||
| self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v)) | ||
|
|
||
| self.register_constraint("raw_sigma", sigma_constraint) | ||
|
|
||
| @property | ||
| def sigma(self) -> Tensor: | ||
| return self.raw_sigma_constraint.transform(self.raw_sigma) | ||
|
|
||
| @sigma.setter | ||
| def sigma(self, value: Tensor) -> None: | ||
| self._set_sigma(value) | ||
|
|
||
| def _set_sigma(self, value: Tensor) -> None: | ||
| if not torch.is_tensor(value): | ||
| value = torch.as_tensor(value).to(self.raw_sigma) | ||
| self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value)) | ||
|
|
||
| def forward(self, function_samples: Tensor, *args: Any, data: Dict[str, Tensor] = {}, **kwargs: Any) -> Categorical: | ||
| if isinstance(function_samples, MultivariateNormal): | ||
| function_samples = function_samples.sample() | ||
|
|
||
| # Compute scaled bin edges | ||
| scaled_edges = self.bin_edges / self.sigma | ||
| scaled_edges_left = torch.cat([scaled_edges, torch.tensor([torch.inf], device=scaled_edges.device)], dim=-1) | ||
| scaled_edges_right = torch.cat([torch.tensor([-torch.inf], device=scaled_edges.device), scaled_edges]) | ||
|
|
||
| # Calculate cumulative probabilities using standard normal CDF (probit function) | ||
| function_samples = function_samples.unsqueeze(-1) | ||
| scaled_edges_left = scaled_edges_left.reshape(1, -1) | ||
| scaled_edges_right = scaled_edges_right.reshape(1, -1) | ||
| probs = inv_probit(scaled_edges_left - function_samples / self.sigma) - inv_probit( | ||
| scaled_edges_right - function_samples / self.sigma | ||
| ) | ||
|
|
||
| return Categorical(probs=probs) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| #!/usr/bin/env python3 | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torch.distributions import Distribution | ||
|
|
||
| from gpytorch.likelihoods import OrdinalLikelihood | ||
| from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase | ||
|
|
||
|
|
||
| class TestOrdinalLikelihood(BaseLikelihoodTestCase, unittest.TestCase): | ||
| seed = 0 | ||
|
|
||
| def create_likelihood(self): | ||
| bin_edges = torch.tensor([-0.5, 0.5]) | ||
| return OrdinalLikelihood(bin_edges) | ||
|
|
||
| def _create_targets(self, batch_shape=torch.Size([])): | ||
| return torch.distributions.Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3])).sample( | ||
| torch.Size([*batch_shape, 5]) | ||
| ) | ||
|
|
||
| def _test_marginal(self, batch_shape): | ||
| likelihood = self.create_likelihood() | ||
| input = self._create_marginal_input(batch_shape) | ||
| output = likelihood(input) | ||
|
|
||
| self.assertTrue(isinstance(output, Distribution)) | ||
| self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5])) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.