
Commit 328ebd0

fix failing unittests and increase coverage
1 parent d8bd497

7 files changed: 118 additions & 23 deletions


gpytorch/distributions/multitask_multivariate_normal.py

Lines changed: 2 additions & 2 deletions
@@ -203,12 +203,12 @@ def get_base_samples(self, sample_shape=torch.Size()):
             return base_samples.view(new_shape).transpose(-1, -2).contiguous()
         return base_samples.view(*sample_shape, *self._output_shape)

-    def log_prob(self, value):
+    def log_prob(self, value, combine_terms=True):
         if not self._interleaved:
             # flip shape of last two dimensions
             new_shape = value.shape[:-2] + value.shape[:-3:-1]
             value = value.view(new_shape).transpose(-1, -2).contiguous()
-        return super().log_prob(value.view(*value.shape[:-2], -1))
+        return super().log_prob(value.view(*value.shape[:-2], -1), combine_terms)

     @property
     def mean(self):
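
Since the multitask class only reshapes `value` and defers to `MultivariateNormal.log_prob`, the new flag behaves identically for multitask distributions. A minimal sketch, not part of the commit, assuming (as the tests below check for the base class) that the split terms sum to the combined log-probability:

import torch
from gpytorch.distributions import MultitaskMultivariateNormal

mean = torch.randn(4, 2)   # 4 data points, 2 tasks
covar = torch.eye(4 * 2)   # joint (interleaved) 8x8 covariance
value = torch.randn(4, 2)

mtmvn = MultitaskMultivariateNormal(mean, covar)
terms = mtmvn.log_prob(value, combine_terms=False)  # forwarded to the base class
assert len(terms) == 3  # the three pieces of the Gaussian log-density
assert torch.isclose(sum(terms), mtmvn.log_prob(value))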

gpytorch/distributions/multivariate_normal.py

Lines changed: 3 additions & 1 deletion
@@ -167,7 +167,9 @@ def log_prob(self, value, combine_terms=True):
         # Get log determininant and first part of quadratic form
         covar = covar.evaluate_kernel()
         inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
-        norm_const = diff.size(-1) * math.log(2 * math.pi)
+        norm_const = torch.tensor(
+            diff.size(-1) * math.log(2 * math.pi)
+        ).to(inv_quad)
         split_terms = [inv_quad, logdet, norm_const]

         if combine_terms:
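
For reference, the split terms are exactly the three pieces of the Gaussian log-density that the combined path reassembles, with d = diff.size(-1):

    \log p(\mathbf{y}) = -\tfrac{1}{2}\left[(\mathbf{y}-\boldsymbol{\mu})^\top \Sigma^{-1} (\mathbf{y}-\boldsymbol{\mu}) + \log\det\Sigma + d\,\log 2\pi\right]

corresponding to inv_quad, logdet, and norm_const in order. Wrapping the constant in torch.tensor(...).to(inv_quad) puts it on the same dtype and device as the other two terms, so all three can be handed back as tensors when combine_terms=False.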

gpytorch/mlls/exact_marginal_log_likelihood.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):

     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately

     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -30,10 +31,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood):
         >>> loss.backward()
     """

-    def __init__(self, likelihood, model):
+    def __init__(self, likelihood, model, combine_terms=True):
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("Likelihood must be Gaussian for exact inference")
-        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model)
+        super(ExactMarginalLogLikelihood, self).__init__(likelihood, model, combine_terms)

     def _add_other_terms(self, res, params):
         # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models)
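
A minimal end-to-end sketch of the new constructor flag, modeled on the new tests further down; the ExactGPModel here (constant mean plus scaled RBF kernel) is an assumption standing in for the helper those tests import, not part of this diff:

import torch
import gpytorch

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        # Prior mean and covariance at the inputs
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.rand(5, 2)
train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model, combine_terms=False)
mll_out = mll(model(train_x), train_y)  # list of terms instead of a single scalar
loss = -1 * sum(mll_out)                # recombine for optimization
loss.backward()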

gpytorch/mlls/leave_one_out_pseudo_likelihood.py

Lines changed: 11 additions & 11 deletions
@@ -29,6 +29,7 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):

     :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model
     :param ~gpytorch.models.ExactGP model: The exact GP model
+    :param ~bool combine_terms (optional): If `False`, the MLL call returns each MLL term separately

     Example:
         >>> # model is a gpytorch.models.ExactGP
@@ -40,11 +41,6 @@ class LeaveOneOutPseudoLikelihood(ExactMarginalLogLikelihood):
         >>> loss.backward()
     """

-    def __init__(self, likelihood, model):
-        super().__init__(likelihood=likelihood, model=model)
-        self.likelihood = likelihood
-        self.model = model
-
     def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         r"""
         Computes the leave one out likelihood given :math:`p(\mathbf f)` and `\mathbf y`
@@ -60,12 +56,16 @@ def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
         identity = torch.eye(*L.shape[-2:], dtype=m.dtype, device=m.device)
         sigma2 = 1.0 / L._cholesky_solve(identity, upper=False).diagonal(dim1=-1, dim2=-2)  # 1 / diag(inv(K))
         mu = target - L._cholesky_solve((target - m).unsqueeze(-1), upper=False).squeeze(-1) * sigma2
-        term1 = -0.5 * sigma2.log()
-        term2 = -0.5 * (target - mu).pow(2.0) / sigma2
-        res = (term1 + term2).sum(dim=-1)
-
-        res = self._add_other_terms(res, params)

         # Scale by the amount of data we have and then add on the scaled constant
         num_data = target.size(-1)
-        return res.div_(num_data) - 0.5 * math.log(2 * math.pi)
+        term1 = sigma2.log().sum(-1)
+        term2 = ((target - mu).pow(2.0) / sigma2).sum(-1)
+        norm_const = torch.tensor(num_data * math.log(2 * math.pi)).to(term1)
+        other_term = self._add_other_terms(torch.zeros_like(term1), params)
+        split_terms = [term1, term2, norm_const, other_term]
+
+        if self.combine_terms:
+            return -0.5 / num_data * sum(split_terms)
+        else:
+            return [-0.5 / num_data * term for term in split_terms]
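
For reference, the quantity being rearranged is the standard leave-one-out pseudo-likelihood (Rasmussen & Williams, GPML, Section 5.4.2); the rewrite computes each summed piece separately and applies the common -0.5 / num_data scaling at the end, so the combined path still returns the per-datapoint average:

    \log p_{\mathrm{LOO}} = \sum_{i=1}^{n} \left( -\tfrac{1}{2}\log\sigma_i^2 - \frac{(y_i - \mu_i)^2}{2\sigma_i^2} - \tfrac{1}{2}\log 2\pi \right),
    \qquad
    \sigma_i^2 = \frac{1}{[K^{-1}]_{ii}},
    \quad
    \mu_i = y_i - \frac{[K^{-1}(\mathbf{y} - \mathbf{m})]_i}{[K^{-1}]_{ii}}

which matches sigma2 and mu as computed from the Cholesky solves above.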

test/distributions/test_multivariate_normal.py

Lines changed: 14 additions & 2 deletions
@@ -219,20 +219,32 @@ def test_log_prob(self, cuda=False):
         var = torch.randn(4, device=device, dtype=dtype).abs_()
         values = torch.randn(4, device=device, dtype=dtype)

-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(mean, torch.eye(4, device=device, dtype=dtype) * var).log_prob(values)
         self.assertLess((res - actual).div(res).abs().item(), 1e-2)

+        res2 = mvn.log_prob(values, combine_terms=False)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().item(), 1e-2)
+
         mean = torch.randn(3, 4, device=device, dtype=dtype)
         var = torch.randn(3, 4, device=device, dtype=dtype).abs_()
         values = torch.randn(3, 4, device=device, dtype=dtype)

-        res = MultivariateNormal(mean, DiagLazyTensor(var)).log_prob(values)
+        mvn = MultivariateNormal(mean, DiagLazyTensor(var))
+        res = mvn.log_prob(values)
         actual = TMultivariateNormal(
             mean, var.unsqueeze(-1) * torch.eye(4, device=device, dtype=dtype).repeat(3, 1, 1)
         ).log_prob(values)
         self.assertLess((res - actual).div(res).abs().norm(), 1e-2)

+        res2 = mvn.log_prob(values, combine_terms=False)
+        assert len(res2) == 3
+        res2 = sum(res2)
+        self.assertLess((res2 - actual).div(res).abs().norm(), 1e-2)
+
     def test_log_prob_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
test/mlls/test_exact_marginal_log_likelihood.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import unittest
+
+import torch
+
+import gpytorch
+
+from .test_leave_one_out_pseudo_likelihood import ExactGPModel
+
+
+class TestExactMarginalLogLikelihood(unittest.TestCase):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
+        train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
+        train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
+        likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
+        model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
+        exact_mll = gpytorch.mlls.ExactMarginalLogLikelihood(
+            likelihood=likelihood,
+            model=model,
+            combine_terms=combine_terms
+        )
+        return train_x, train_y, exact_mll
+
+    def test_smoke(self):
+        """Make sure the exact_mll works without batching."""
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        loss.backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
+    def test_smoke_batch(self):
+        """Make sure the exact_mll works without batching."""
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
+        output = exact_mll.model(train_x)
+        loss = -exact_mll(output, train_y)
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+        train_x, train_y, exact_mll = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = exact_mll.model(train_x)
+        mll_out = exact_mll(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
+
+if __name__ == "__main__":
+    unittest.main()
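
Assuming the new module lives at test/mlls/test_exact_marginal_log_likelihood.py alongside the LOOCV test it imports ExactGPModel from, the suite can be run on its own with the standard unittest runner:

    python -m unittest test.mlls.test_exact_marginal_log_likelihood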

test/mlls/test_leave_one_out_pseudo_likelihood.py

Lines changed: 26 additions & 5 deletions
@@ -21,36 +21,57 @@ def forward(self, x):


 class TestLeaveOneOutPseudoLikelihood(unittest.TestCase):
-    def get_data(self, shapes, dtype=None, device=None):
+    def get_data(self, shapes, combine_terms, dtype=None, device=None):
         train_x = torch.rand(*shapes, dtype=dtype, device=device, requires_grad=True)
         train_y = torch.sin(train_x[..., 0]) + torch.cos(train_x[..., 1])
         likelihood = gpytorch.likelihoods.GaussianLikelihood().to(dtype=dtype, device=device)
         model = ExactGPModel(train_x, train_y, likelihood).to(dtype=dtype, device=device)
-        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood=likelihood, model=model)
+        loocv = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
+            likelihood=likelihood,
+            model=model,
+            combine_terms=combine_terms
+        )
         return train_x, train_y, loocv

     def test_smoke(self):
         """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([5, 2])
+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         loss.backward()
         self.assertTrue(train_x.grad is not None)

+        train_x, train_y, loocv = self.get_data([5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        loss.backward()
+        assert len(mll_out) == 4
+        self.assertTrue(train_x.grad is not None)
+
     def test_smoke_batch(self):
         """Make sure the loocv works without batching."""
-        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2])
+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=True)
         output = loocv.model(train_x)
         loss = -loocv(output, train_y)
         assert loss.shape == (3, 3, 3)
         loss.sum().backward()
         self.assertTrue(train_x.grad is not None)

+        train_x, train_y, loocv = self.get_data([3, 3, 3, 5, 2], combine_terms=False)
+        output = loocv.model(train_x)
+        mll_out = loocv(output, train_y)
+        loss = -1 * sum(mll_out)
+        assert len(mll_out) == 4
+        assert loss.shape == (3, 3, 3)
+        loss.sum().backward()
+        self.assertTrue(train_x.grad is not None)
+
     def test_check_bordered_system(self):
         """Make sure that the bordered system solves match the naive solution."""
         n = 5
         # Compute the pseudo-likelihood via the bordered systems in O(n^3)
-        train_x, train_y, loocv = self.get_data([n, 2], dtype=torch.float64)
+        train_x, train_y, loocv = self.get_data([n, 2], combine_terms=True, dtype=torch.float64)
         output = loocv.model(train_x)
         loocv_1 = loocv(output, train_y)
