
Commit c6a4f46

Merge pull request #2317 from yyexela/dGPFantasize
Enable fantasy models for multitask GPs Reborn
2 parents 527546e + 28ee4ca commit c6a4f46
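
This merge removes the long-standing restriction in `get_fantasy_model`, which raised "Cannot yet add fantasy observations to multitask GPs, but this is coming soon!" whenever the targets carried a trailing task dimension. A minimal sketch of what the change enables (the model definition and shapes below are illustrative, not taken from the commit; the commit's own end-to-end check is the derivative-GP test added below):

import torch
import gpytorch
from gpytorch.distributions import MultitaskMultivariateNormal

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=2)
        self.covar_module = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=2)

    def forward(self, x):
        return MultitaskMultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x = torch.linspace(0, 1, 15).unsqueeze(-1)          # (n, d) = (15, 1)
train_y = torch.randn(15, 2)                              # (n, t) multitask targets
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = MultitaskGPModel(train_x, train_y, likelihood)
model.eval()
likelihood.eval()

model(torch.rand(5, 1))                                   # populate the prediction caches first
fant_x = torch.rand(3, 1)                                 # m = 3 new inputs
fant_y = torch.randn(3, 2)                                # (m, t) new targets
fantasy_model = model.get_fantasy_model(fant_x, fant_y)   # previously raised a RuntimeError

One prerequisite is unchanged by this PR: the model must have produced at least one posterior, so that `prediction_strategy` and its caches exist before fantasizing.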

File tree

3 files changed: 95 additions, 11 deletions

gpytorch/models/exact_gp.py

Lines changed: 9 additions & 8 deletions

@@ -6,7 +6,7 @@
 import torch
 
 from .. import settings
-from ..distributions import MultivariateNormal
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
 from ..likelihoods import _GaussianLikelihoodBase
 from ..utils.generic import length_safe_zip
 from ..utils.warnings import GPInputWarning
@@ -162,15 +162,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
 
         model_batch_shape = self.train_inputs[0].shape[:-2]
 
-        if self.train_targets.dim() > len(model_batch_shape) + 1:
-            raise RuntimeError("Cannot yet add fantasy observations to multitask GPs, but this is coming soon!")
-
         if not isinstance(inputs, list):
             inputs = [inputs]
 
         inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]
 
-        target_batch_shape = targets.shape[:-1]
+        if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateNormal):
+            data_dim_start = -1
+        else:
+            data_dim_start = -2
+
+        target_batch_shape = targets.shape[:data_dim_start]
         input_batch_shape = inputs[0].shape[:-2]
         tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)
 
@@ -198,7 +200,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
         # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
         train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
-        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[-1:])
+        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])
 
         full_inputs = [
             torch.cat(
@@ -208,8 +210,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
             for train_input, input in length_safe_zip(train_inputs, inputs)
         ]
         full_targets = torch.cat(
-            [train_targets, targets.expand(target_batch_shape + targets.shape[-1:])],
-            dim=-1,
+            [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
         )
 
         try:
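
The crux of the change above is `data_dim_start`: a single-task GP stores targets as `(*batch, n)`, so training and fantasy targets concatenate along the last dimension, while a multitask GP stores `(*batch, n, t)` and must concatenate along `dim=-2` so the task dimension stays last. A small shape sketch of the two conventions (the tensors here are illustrative only):

import torch

# Single-task: targets are (*batch, n); the data dimension is the last one.
single = torch.randn(5, 15)                    # batch=(5,), n=15
assert single.shape[:-1] == torch.Size([5])    # target_batch_shape

# Multitask: targets are (*batch, n, t); data rows sit at dim -2, tasks at -1.
multi = torch.randn(5, 15, 2)                  # batch=(5,), n=15, t=2
assert multi.shape[:-2] == torch.Size([5])     # target_batch_shape

# Appending m = 3 fantasy points therefore uses dim=-1 vs. dim=-2:
full_single = torch.cat([single, torch.randn(5, 3)], dim=-1)    # (5, 18)
full_multi = torch.cat([multi, torch.randn(5, 3, 2)], dim=-2)   # (5, 18, 2)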

gpytorch/models/exact_prediction_strategies.py

Lines changed: 24 additions & 3 deletions

@@ -22,6 +22,8 @@
 from torch import Tensor
 
 from .. import settings
+
+from ..distributions import MultitaskMultivariateNormal
 from ..lazy import LazyEvaluatedKernelTensor
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
 
@@ -134,16 +136,28 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
             been added and all test-time caches have been updated.
         """
+        if not isinstance(full_output, MultitaskMultivariateNormal):
+            target_batch_shape = targets.shape[:-1]
+        else:
+            target_batch_shape = targets.shape[:-2]
+
         full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
 
         batch_shape = full_inputs[0].shape[:-2]
 
-        full_mean = full_mean.view(*batch_shape, -1)
         num_train = self.num_train
 
+        if isinstance(full_output, MultitaskMultivariateNormal):
+            num_tasks = full_output.event_shape[-1]
+            full_mean = full_mean.view(*batch_shape, -1, num_tasks)
+            fant_mean = full_mean[..., (num_train // num_tasks) :, :]
+            full_targets = full_targets.view(*target_batch_shape, -1)
+        else:
+            full_mean = full_mean.view(*batch_shape, -1)
+            fant_mean = full_mean[..., num_train:]
+
         # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
         fant_fant_covar = full_covar[..., num_train:, num_train:]
-        fant_mean = full_mean[..., num_train:]
         mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
         fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
         mvn_obs = fant_likelihood(mvn, inputs, **kwargs)
@@ -209,6 +223,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
                 new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
             # no need to repeat the covar cache, broadcasting will do the right thing
 
+        if isinstance(full_output, MultitaskMultivariateNormal):
+            full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()
+
         # Create new DefaultPredictionStrategy object
         fant_strat = self.__class__(
             train_inputs=full_inputs,
@@ -285,7 +302,11 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
         # NOTE TO FUTURE SELF:
         # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
         # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
-        res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+
+        if len(self.mean_cache.shape) == 4:
+            res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
+        else:
+            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
         res = res + test_mean
 
         return res
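
Two details are worth noting here. First, a `MultitaskMultivariateNormal` flattens its `n` points and `t` tasks into `n * t` outputs, and `self.num_train` counts those flattened entries, which is why the fantasy slice divides by `num_tasks` once the mean has been viewed back to `(..., n, t)`. Second, the new branch in `exact_predictive_mean` squeezes out a singleton broadcast dimension (`squeeze(1)`) when the batched multitask fantasy update leaves a 4-dimensional `mean_cache`. A shape illustration of the first point, under assumed sizes (n = 15 training points, m = 3 fantasy points, t = 2 tasks; all names are local to this sketch):

import torch

num_tasks = 2
n_train, m = 15, 3
num_train = n_train * num_tasks                     # self.num_train counts flattened (n * t) entries

full_mean = torch.randn((n_train + m) * num_tasks)  # flat mean over all (n + m) * t outputs
full_mean = full_mean.view(-1, num_tasks)           # (18, 2): rows = points, columns = tasks
fant_mean = full_mean[..., (num_train // num_tasks):, :]
assert fant_mean.shape == (m, num_tasks)            # (3, 2): only the fantasy rows remain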

(new test file added by this PR; filename not shown in this view)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+
+import unittest
+from math import pi
+
+import torch
+
+import gpytorch
+from gpytorch.distributions import MultitaskMultivariateNormal
+from gpytorch.kernels import ScaleKernel, RBFKernelGrad
+from gpytorch.likelihoods import MultitaskGaussianLikelihood
+from gpytorch.means import ConstantMeanGrad
+from gpytorch.test.base_test_case import BaseTestCase
+
+# Simple training data
+num_train_samples = 15
+num_fantasies = 10
+dim = 1
+train_X = torch.linspace(0, 1, num_train_samples).reshape(-1, 1)
+train_Y = torch.hstack([
+    torch.sin(train_X * (2 * pi)).reshape(-1, 1),
+    (2 * pi) * torch.cos(train_X * (2 * pi)).reshape(-1, 1),
+])
+
+
+class GPWithDerivatives(gpytorch.models.ExactGP):
+    def __init__(self, train_X, train_Y):
+        likelihood = MultitaskGaussianLikelihood(num_tasks=1 + dim)
+        super().__init__(train_X, train_Y, likelihood)
+        self.mean_module = ConstantMeanGrad()
+        self.base_kernel = RBFKernelGrad()
+        self.covar_module = ScaleKernel(self.base_kernel)
+        self._num_outputs = 1 + dim
+
+    def forward(self, x):
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        return MultitaskMultivariateNormal(mean_x, covar_x)
+
+
+class TestDerivativeGPFutures(BaseTestCase, unittest.TestCase):
+
+    # Inspired by test_lanczos_fantasy_model
+    def test_derivative_gp_futures(self):
+        model = GPWithDerivatives(train_X, train_Y)
+        mll = gpytorch.mlls.sum_marginal_log_likelihood.ExactMarginalLogLikelihood(model.likelihood, model)
+
+        mll.train()
+        mll.eval()
+
+        # get a posterior to fill in caches
+        model(torch.randn(num_train_samples).reshape(-1, 1))
+
+        new_x = torch.randn((1, 1, dim))
+        new_y = torch.randn((num_fantasies, 1, 1, 1 + dim))
+
+        # just check that this can run without error
+        model.get_fantasy_model(new_x, new_y)
+
+
+if __name__ == "__main__":
+    unittest.main()
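
Two shape conventions the test relies on: a derivative GP over `d` input dimensions produces `1 + d` tasks per point (the function value followed by its `d` partial derivatives), and the leading `num_fantasies` dimension on `new_y` asks a single `get_fantasy_model` call for a batch of 10 independent fantasy models. A quick check of the target layout (illustrative, mirroring the training data above):

import torch
from math import pi

n, d = 15, 1
X = torch.linspace(0, 1, n).reshape(-1, 1)
Y = torch.hstack([
    torch.sin(2 * pi * X),           # task 0: f(x)
    2 * pi * torch.cos(2 * pi * X),  # task 1: df/dx
])
assert Y.shape == (n, 1 + d)         # one function value plus d derivatives per point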
