10 changes: 6 additions & 4 deletions gpytorch/functions/__init__.py
@@ -66,14 +66,15 @@ def dsmm(sparse_mat, dense_mat):
return DSMM(sparse_mat)(dense_mat)


def exact_predictive_mean(full_covar, full_mean, train_labels, num_train, likelihood, precomputed_cache=None):
def exact_predictive_mean(full_covar, full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache=None):
"""
Computes the posterior predictive mean of a GP

Args:
- full_covar ( (n+t) x (n+t) ) - the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]
- full_mean (n + t) - the training and test prior means, stacked on top of each other
- train_inputs - the training inputs (passed through to the likelihood's noise model)
- train_labels (n) - the training labels minus the training prior mean
- num_train (int) - how many training points are there in the full covariance matrix
- likelihood - the likelihood, which supplies the observed noise
- precomputed_cache - speeds up subsequent computations (default: None)
@@ -88,16 +89,17 @@ def exact_predictive_mean(full_covar, full_mean, train_labels, num_train, likeli
from ..lazy.non_lazy_tensor import NonLazyTensor

full_covar = NonLazyTensor(full_covar)
return full_covar.exact_predictive_mean(full_mean, train_labels, num_train, likelihood, precomputed_cache)
return full_covar.exact_predictive_mean(full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache)


def exact_predictive_covar(full_covar, num_train, likelihood, precomputed_cache=None):
def exact_predictive_covar(full_covar, train_inputs, num_train, likelihood, precomputed_cache=None):
"""
Computes the posterior predictive covariance of a GP

Args:
- full_covar ( (n+t) x (n+t) ) - the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]
- train_inputs - the training inputs (passed through to the likelihood's noise model)
- num_train (int) - how many training points are there in the full covariance matrix
- likelihood - the likelihood, which supplies the observed noise
- precomputed_cache - speeds up subsequent computations (default: None)
@@ -112,7 +114,7 @@ def exact_predictive_covar(full_covar, num_train, likelihood, precomputed_cache=
from ..lazy.non_lazy_tensor import NonLazyTensor

full_covar = NonLazyTensor(full_covar)
return full_covar.exact_predictive_covar(num_train, likelihood, precomputed_cache)
return full_covar.exact_predictive_covar(train_inputs, num_train, likelihood, precomputed_cache)


def log_normal_cdf(x):
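For reference, these two helpers compute the standard GP posterior predictive equations. In the block notation of the docstrings above, with \Sigma the observation-noise covariance that the likelihood now builds from train_inputs:

\mu_* = m(X_*) + K_{X_* X} \, (K_{XX} + \Sigma)^{-1} \, (y - m(X))
\Sigma_* = K_{X_* X_*} - K_{X_* X} \, (K_{XX} + \Sigma)^{-1} \, K_{X X_*}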
6 changes: 6 additions & 0 deletions gpytorch/lazy/diag_lazy_tensor.py
@@ -72,6 +72,12 @@ def sum_batch(self, sum_batch_size=None):

return self.__class__(diag.sum(-2))

def exp(self):
return DiagLazyTensor(self._diag.exp())

def sqrt(self):
return DiagLazyTensor(self._diag.sqrt())

def zero_mean_mvn_samples(self, num_samples):
if self.ndimension() == 3:
base_samples = torch.randn(
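Both new methods act elementwise on the stored diagonal and return a new DiagLazyTensor. A minimal usage sketch (the tensor values are illustrative):

import torch
from gpytorch.lazy import DiagLazyTensor

d = DiagLazyTensor(torch.tensor([1.0, 4.0]))
print(d.exp().diag())   # elementwise exp of the diagonal: [e^1, e^4]
print(d.sqrt().diag())  # elementwise sqrt of the diagonal: [1., 2.]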
10 changes: 5 additions & 5 deletions gpytorch/lazy/interpolated_lazy_tensor.py
@@ -365,7 +365,7 @@ def diag(self):
res = res.view(batch_size, n_data, -1).sum(-1)
return res

def exact_predictive_mean(self, full_mean, train_labels, num_train, likelihood, precomputed_cache=None):
def exact_predictive_mean(self, full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache=None):
from ..distributions import MultivariateNormal

if precomputed_cache is None:
Expand All @@ -383,7 +383,7 @@ def exact_predictive_mean(self, full_mean, train_labels, num_train, likelihood,

train_mean = full_mean.narrow(-1, 0, train_train_covar.size(-1))

mvn = likelihood(MultivariateNormal(train_mean, train_train_covar))
mvn = likelihood(MultivariateNormal(train_mean, train_train_covar), train_inputs)
train_mean, train_train_covar = mvn.mean, mvn.lazy_covariance_matrix

train_train_covar_inv_labels = train_train_covar.inv_matmul((train_labels - train_mean).unsqueeze(-1))
@@ -423,11 +423,11 @@ def _exact_predictive_covar_inv_quad_form_root(self, precomputed_cache, test_tra
res = left_interp(test_interp_indices, test_interp_values, precomputed_cache)
return res

def exact_predictive_covar(self, num_train, likelihood, precomputed_cache=None):
def exact_predictive_covar(self, train_inputs, num_train, likelihood, precomputed_cache=None):
from ..distributions import MultivariateNormal

if not beta_features.fast_pred_var.on() and not beta_features.fast_pred_samples.on():
return super(InterpolatedLazyTensor, self).exact_predictive_covar(num_train, likelihood, precomputed_cache)
return super(InterpolatedLazyTensor, self).exact_predictive_covar(train_inputs, num_train, likelihood, precomputed_cache)

n_test = self.size(-2) - num_train
train_interp_indices = self.left_interp_indices.narrow(-2, 0, num_train)
@@ -453,7 +453,7 @@ def exact_predictive_covar(self, num_train, likelihood, precomputed_cache=None):
)

grv = MultivariateNormal(torch.zeros(1), train_train_covar)
train_train_covar = likelihood(grv).lazy_covariance_matrix
train_train_covar = likelihood(grv, train_inputs).lazy_covariance_matrix

# Get probe vectors for inverse root
num_probe_vectors = beta_features.fast_pred_var.num_probe_vectors()
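The interpolated fast path above only runs when the relevant beta feature is switched on; otherwise the call falls through to the generic LazyTensor implementation, now with train_inputs threaded through. A sketch of toggling the flag (assuming beta_features.fast_pred_var is usable as a context manager, which the .on() and .num_probe_vectors() calls above suggest; model, likelihood, and test_x are placeholders):

import gpytorch

with gpytorch.beta_features.fast_pred_var():
    # inside this context, exact_predictive_covar builds its probe-vector
    # cache instead of using the exact fallback
    pred = likelihood(model(test_x), test_x)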
23 changes: 17 additions & 6 deletions gpytorch/lazy/lazy_evaluated_kernel_tensor.py
@@ -47,6 +47,17 @@ def _quad_form_derivative(self, left_vecs, right_vecs):
def _transpose_nonbatch(self):
return self.__class__(self.kernel, self.x2, self.x1, **self.params)

def _batch_get_indices(self, batch_indices, left_indices, right_indices):
from ..kernels import Kernel

x1 = self.x1[batch_indices, left_indices, :].unsqueeze(0)
x2 = self.x2[batch_indices, right_indices, :].unsqueeze(0)
res = super(Kernel, self.kernel).__call__(x1.transpose(-1, -2), x2.transpose(-1, -2))
if isinstance(res, LazyTensor):
res = res.evaluate()
res = res.view(-1)
return res

def _get_indices(self, left_indices, right_indices):
from ..kernels import Kernel

@@ -168,22 +179,22 @@ def representation_tree(self):
def evaluate(self):
return self.evaluate_kernel().evaluate()

def exact_predictive_mean(self, full_mean, train_labels, num_train, likelihood, precomputed_cache=None):
def exact_predictive_mean(self, full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache=None):
if self.kernel.has_custom_exact_predictions:
return self.evaluate_kernel().exact_predictive_mean(
full_mean, train_labels, num_train, likelihood, precomputed_cache
full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache
)
else:
return super(LazyEvaluatedKernelTensor, self).exact_predictive_mean(
full_mean, train_labels, num_train, likelihood, precomputed_cache
full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache
)

def exact_predictive_covar(self, num_train, likelihood, precomputed_cache=None):
def exact_predictive_covar(self, train_inputs, num_train, likelihood, precomputed_cache=None):
if self.kernel.has_custom_exact_predictions:
return self.evaluate_kernel().exact_predictive_covar(num_train, likelihood, precomputed_cache)
return self.evaluate_kernel().exact_predictive_covar(train_inputs, num_train, likelihood, precomputed_cache)
else:
return super(LazyEvaluatedKernelTensor, self).exact_predictive_covar(
num_train, likelihood, precomputed_cache
train_inputs, num_train, likelihood, precomputed_cache
)

def repeat(self, *sizes):
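LazyEvaluatedKernelTensor only defers to the evaluated kernel when the kernel opts in via has_custom_exact_predictions. A hypothetical sketch of a kernel using that hook (the subclass and attribute placement are assumptions; only the flag name comes from this diff):

from gpytorch.kernels import RBFKernel

class MyPredictiveKernel(RBFKernel):
    # opting in routes exact_predictive_mean/covar through evaluate_kernel()
    has_custom_exact_predictions = True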
10 changes: 6 additions & 4 deletions gpytorch/lazy/lazy_tensor.py
@@ -428,14 +428,15 @@ def evaluate_kernel(self):
"""
return self.representation_tree()(*self.representation())

def exact_predictive_mean(self, full_mean, train_labels, num_train, likelihood, precomputed_cache=None):
def exact_predictive_mean(self, full_mean, train_inputs, train_labels, num_train, likelihood, precomputed_cache=None):
"""
Computes the posterior predictive mean of a GP
Assumes that self is the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]

Args:
full_mean (:obj:`torch.tensor`): the training and test prior means, stacked on top of each other
train_inputs (:obj:`torch.tensor`): the training inputs (passed through to the likelihood's noise model)
train_labels (:obj:`torch.tensor`): the training labels minus the training prior mean
likelihood (:obj:`gpytorch.likelihoods.Likelihood`): the likelihood, which supplies the observed noise
precomputed_cache (optional): speeds up subsequent computations (default: None)
@@ -453,7 +454,7 @@ def exact_predictive_mean(full_mean, train_labels, num_train, likelihood,
train_train_covar = self[:num_train, :num_train]

train_mean = full_mean.narrow(-1, 0, train_train_covar.size(-1))
mvn = likelihood(MultivariateNormal(train_mean, train_train_covar))
mvn = likelihood(MultivariateNormal(train_mean, train_train_covar), train_inputs)
train_mean, train_train_covar = mvn.mean, mvn.lazy_covariance_matrix

train_labels_offset = train_labels - train_mean
@@ -472,13 +473,14 @@ def exact_predictive_mean(full_mean, train_labels, num_train, likelihood,
res = res + test_mean
return res, precomputed_cache.detach()

def exact_predictive_covar(self, num_train, likelihood, precomputed_cache=None):
def exact_predictive_covar(self, train_inputs, num_train, likelihood, precomputed_cache=None):
"""
Computes the posterior predictive covariance of a GP
Assumes that self is the block prior covariance matrix of training and testing points
[ K_XX, K_XX*; K_X*X, K_X*X* ]

Args:
train_inputs (:obj:`torch.tensor`): the training inputs (TODO: with these available, the num_train arg could be dropped from this signature as well)
num_train (int): The number of training points in the full covariance matrix
likelihood (:obj:`gpytorch.likelihoods.Likelihood`): the likelihood, which supplies the observed noise
precomputed_cache (optional): speeds up subsequent computations (default: None)
@@ -498,7 +500,7 @@ def exact_predictive_covar(self, num_train, likelihood, precomputed_cache=None):
test_train_covar = self[num_train:, :num_train]
test_test_covar = self[num_train:, num_train:]

train_train_covar = likelihood(MultivariateNormal(torch.zeros(1), train_train_covar)).lazy_covariance_matrix
train_train_covar = likelihood(MultivariateNormal(torch.zeros(1), train_train_covar), train_inputs).lazy_covariance_matrix
if not beta_features.fast_pred_var.on():
from .matmul_lazy_tensor import MatmulLazyTensor

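The updated signatures mean callers supply the training inputs alongside the labels, so the likelihood can evaluate its noise model on them. A minimal sketch of the intended call pattern (all tensors, shapes, and the likelihood object here are placeholder assumptions):

import torch
from gpytorch.lazy import NonLazyTensor
from gpytorch.likelihoods import GaussianLikelihood

num_train = 80
full_mean = torch.zeros(100)
full_covar = NonLazyTensor(torch.eye(100))  # stands in for [ K_XX, K_XX*; K_X*X, K_X*X* ]
# last dim sized so the noise model's shape rule (params.shape[:-2] + params.shape[-1:])
# yields one noise variance per training point
train_inputs = torch.randn(1, num_train)
train_labels = torch.randn(num_train)
likelihood = GaussianLikelihood()

pred_mean, mean_cache = full_covar.exact_predictive_mean(
    full_mean, train_inputs, train_labels, num_train, likelihood
)
pred_covar, covar_cache = full_covar.exact_predictive_covar(train_inputs, num_train, likelihood)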
25 changes: 16 additions & 9 deletions gpytorch/likelihoods/__init__.py
@@ -1,18 +1,25 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import, division, print_function, unicode_literals

from .likelihood import Likelihood
from .gaussian_likelihood import GaussianLikelihood
from .multitask_gaussian_likelihood import MultitaskGaussianLikelihood
from .bernoulli_likelihood import BernoulliLikelihood
from .gaussian_likelihood import _GaussianLikelihoodBase, GaussianLikelihood
from .likelihood import Likelihood
from .multitask_gaussian_likelihood import (
_MultitaskGaussianLikelihoodBase,
MultitaskGaussianLikelihood,
MultitaskGaussianLikelihood_Kronecker,
)
from .noise_models import HeteroskedasticNoise
from .softmax_likelihood import SoftmaxLikelihood


__all__ = [
"Likelihood",
"_GaussianLikelihoodBase",
"_MultitaskGaussianLikelihoodBase",
"BernoulliLikelihood",
"GaussianLikelihood",
"HeteroskedasticNoise",
"Likelihood",
"MultitaskGaussianLikelihood",
"BernoulliLikelihood",
"MultitaskGaussianLikelihood_Kronecker",
"SoftmaxLikelihood",
]
55 changes: 29 additions & 26 deletions gpytorch/likelihoods/gaussian_likelihood.py
@@ -1,42 +1,45 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import math
import torch
from ..distributions import MultivariateNormal
from ..functions import add_diag
from ..likelihoods import Likelihood
from .. import settings
import warnings
import logging

from .. import settings
from ..distributions import MultivariateNormal
from ..lazy import AddedDiagLazyTensor, DiagLazyTensor
from .likelihood import Likelihood
from .noise_models import HomoskedasticNoise

class GaussianLikelihood(Likelihood):
r"""
"""

def __init__(self, log_noise_prior=None, batch_size=1):
super(GaussianLikelihood, self).__init__()
self.register_parameter(
name="log_noise", parameter=torch.nn.Parameter(torch.zeros(batch_size, 1)), prior=log_noise_prior
)

@property
def noise(self):
return self.log_noise.exp()
class _GaussianLikelihoodBase(Likelihood):
def __init__(self, log_noise_covar):
super(_GaussianLikelihoodBase, self).__init__()
self.log_noise_covar = log_noise_covar

def forward(self, input):
def forward(self, input, *params):
if not isinstance(input, MultivariateNormal):
raise ValueError("GaussianLikelihood requires a MultivariateNormal input")
raise ValueError("Gaussian Likelihoods require a MultivariateNormal input")
mean, covar = input.mean, input.lazy_covariance_matrix
noise = self.noise
if covar.ndimension() == 2:
if settings.debug.on() and noise.size(0) > 1:
raise RuntimeError("With batch_size > 1, expected a batched MultivariateNormal distribution.")
noise = noise.squeeze(0)
log_noise_covar = self.log_noise_covar(*params)
if isinstance(log_noise_covar, DiagLazyTensor):
full_covar = AddedDiagLazyTensor(covar, log_noise_covar.exp())
else:
# TODO: Properly deal with non-diagonal noise covariance models
full_covar = covar + log_noise_covar.exp()
[Review comment from a member] We may want to do full_covar = full_covar.add_jitter() after this call so that we can use preconditioning.

return input.__class__(mean, full_covar)

return input.__class__(mean, add_diag(covar, noise))
def variational_log_probability(self, input, target):
raise NotImplementedError


class GaussianLikelihood(_GaussianLikelihoodBase):
def __init__(self, log_noise_prior=None, batch_size=1):
log_noise_covar = HomoskedasticNoise(log_noise_prior=log_noise_prior, batch_size=batch_size)
super(GaussianLikelihood, self).__init__(log_noise_covar=log_noise_covar)

def variational_log_probability(self, input, target):
mean, variance = input.mean, input.variance
log_noise = self.log_noise
log_noise = self.log_noise_covar.log_noise
if variance.ndimension() == 1:
if settings.debug.on() and log_noise.size(0) > 1:
raise RuntimeError("With batch_size > 1, expected a batched MultivariateNormal distribution.")
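After the refactor the noise model is a submodule of the likelihood, and extra positional arguments to the likelihood are routed to it. A minimal sketch of the new pattern (the prior, inputs, and shapes are illustrative assumptions about the in-progress API):

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import GaussianLikelihood

likelihood = GaussianLikelihood()
# the noise parameter now lives on the noise-covariance submodule
print(likelihood.log_noise_covar.log_noise)  # Parameter of shape (1, 1)

prior = MultivariateNormal(torch.zeros(10), torch.eye(10))
train_x = torch.randn(1, 10)  # extra positional args are routed to the noise model
observed = likelihood(prior, train_x)  # adds exp(log_noise) along the covariance diagonal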
26 changes: 26 additions & 0 deletions gpytorch/likelihoods/homoskedastic_noise.py
@@ -0,0 +1,26 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.nn import Parameter

from ..lazy import DiagLazyTensor
from ..module import Module


class HomoskedasticNoise(Module):
def __init__(self, log_noise_prior=None, batch_size=1):
super(HomoskedasticNoise, self).__init__()
self.register_parameter(
name="log_noise", parameter=Parameter(torch.zeros(batch_size, 1)), prior=log_noise_prior
)

def forward(self, params):
noise = self.log_noise.exp()
if isinstance(params, list):
variance_shape = params[0].shape[:-2] + params[0].shape[-1:]
else:
variance_shape = params.shape[:-2] + params.shape[-1:]
if len(variance_shape) == 1:
noise = noise.squeeze(0)
variances = noise * torch.ones(*variance_shape, dtype=noise.dtype, device=noise.device)
return DiagLazyTensor(variances)
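HomoskedasticNoise expands its single log_noise parameter into a constant diagonal whose shape is read off the input as params.shape[:-2] + params.shape[-1:]. A minimal sketch (the import path follows this diff's file location, and the input shape is an assumption for illustration):

import torch
from gpytorch.likelihoods.homoskedastic_noise import HomoskedasticNoise

noise_model = HomoskedasticNoise()
params = torch.randn(2, 3, 5)      # variance shape = (2,) + (5,) = (2, 5)
noise_covar = noise_model(params)  # DiagLazyTensor with constant diagonal exp(log_noise)
print(noise_covar.diag().shape)    # torch.Size([2, 5])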