Skip to content

Commit 6c2fd48

Browse files
committed
First unit test passing
1 parent b9dc064 commit 6c2fd48

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
from math import pi
5+
6+
import torch
7+
8+
import gpytorch
9+
from gpytorch.distributions import MultitaskMultivariateNormal
10+
from gpytorch.kernels import ScaleKernel, RBFKernelGrad
11+
from gpytorch.likelihoods import MultitaskGaussianLikelihood
12+
from gpytorch.means import ConstantMeanGrad
13+
from gpytorch.test.base_test_case import BaseTestCase
14+
15+
# Simple training data
16+
num_train_samples = 15
17+
num_fantasies = 10
18+
dim = 1
19+
train_X = torch.linspace(0, 1, num_train_samples).reshape(-1, 1)
20+
train_Y = torch.hstack([
21+
torch.sin(train_X * (2 * pi)).reshape(-1, 1),
22+
(2 * pi) * torch.cos(train_X * (2 * pi)).reshape(-1, 1),
23+
])
24+
25+
26+
class GPWithDerivatives(gpytorch.models.ExactGP):
27+
def __init__(self, train_X, train_Y):
28+
likelihood = MultitaskGaussianLikelihood(num_tasks=1 + dim)
29+
super().__init__(train_X, train_Y, likelihood)
30+
self.mean_module = ConstantMeanGrad()
31+
self.base_kernel = RBFKernelGrad()
32+
self.covar_module = ScaleKernel(self.base_kernel)
33+
self._num_outputs = 1 + dim
34+
35+
def forward(self, x):
36+
mean_x = self.mean_module(x)
37+
covar_x = self.covar_module(x)
38+
return MultitaskMultivariateNormal(mean_x, covar_x)
39+
40+
41+
class TestDerivativeGPFutures(BaseTestCase, unittest.TestCase):
42+
43+
# Inspired by test_lanczos_fantasy_model
44+
def test_derivative_gp_futures(self):
45+
model = GPWithDerivatives(train_X, train_Y)
46+
mll = gpytorch.mlls.sum_marginal_log_likelihood.ExactMarginalLogLikelihood(model.likelihood, model)
47+
48+
mll.train()
49+
mll.eval()
50+
51+
# get a posterior to fill in caches
52+
model(torch.randn(num_train_samples).reshape(-1, 1))
53+
54+
new_x = torch.randn((1, 1, dim))
55+
new_y = torch.randn((num_fantasies, 1, 1, 1 + dim))
56+
57+
# just check that this can run without error
58+
model.get_fantasy_model(new_x, new_y)
59+
60+
61+
if __name__ == "__main__":
62+
unittest.main()

0 commit comments

Comments
 (0)