
Commit 969a9ec

qKG First commit
1 parent a2b5fd8 commit 969a9ec

2 files changed: 39 additions, 13 deletions

gpytorch/models/exact_gp.py

Lines changed: 10 additions & 9 deletions
@@ -6,7 +6,7 @@
 import torch

 from .. import settings
-from ..distributions import MultivariateNormal
+from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
 from ..likelihoods import _GaussianLikelihoodBase
 from ..utils.generic import length_safe_zip
 from ..utils.warnings import GPInputWarning
@@ -162,15 +162,17 @@ def get_fantasy_model(self, inputs, targets, **kwargs):

         model_batch_shape = self.train_inputs[0].shape[:-2]

-        if self.train_targets.dim() > len(model_batch_shape) + 1:
-            raise RuntimeError("Cannot yet add fantasy observations to multitask GPs, but this is coming soon!")
-
         if not isinstance(inputs, list):
             inputs = [inputs]

         inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

-        target_batch_shape = targets.shape[:-1]
+        if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateNormal):
+            data_dim_start = -1
+        else:
+            data_dim_start = -2
+
+        target_batch_shape = targets.shape[:data_dim_start]
         input_batch_shape = inputs[0].shape[:-2]
         tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

@@ -198,7 +200,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
         # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
         train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
-        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[-1:])
+        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])

         full_inputs = [
             torch.cat(
@@ -208,8 +210,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
             for train_input, input in length_safe_zip(train_inputs, inputs)
         ]
         full_targets = torch.cat(
-            [train_targets, targets.expand(target_batch_shape + targets.shape[-1:])],
-            dim=-1,
+            [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
         )

         try:
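
Note on the reshaping above: `data_dim_start` simply selects how many trailing dimensions of `targets` are data (one for single-task, two for multitask), and the train/fantasy concatenation runs along that same dimension. A minimal shape sketch; the tensors below are illustrative, not from the commit:

import torch

# Illustrative shapes only: a batch of 3 models, 5 train targets, 2 fantasy targets.
single = torch.randn(3, 5)      # single-task targets: the data dim is last
multi = torch.randn(3, 5, 2)    # multitask targets: a trailing task dim follows

# data_dim_start = -1 (single-task) or -2 (multitask) strips the data (and task)
# dims, leaving the batch shape in both cases.
assert single.shape[:-1] == multi.shape[:-2] == torch.Size([3])

# Concatenating train and fantasy targets runs along that same data dim.
fant = torch.randn(3, 2, 2)
full = torch.cat([multi, fant], dim=-2)
assert full.shape == torch.Size([3, 7, 2])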
@@ -316,7 +317,7 @@ def __call__(self, *args, **kwargs):
         if settings.debug().on():
             if not isinstance(full_output, MultivariateNormal):
                 raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
-            full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix
+            full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

         # Determine the shape of the joint distribution
         batch_shape = full_output.batch_shape
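
With the multitask guard removed and the shape handling above, `get_fantasy_model` can now be called with multitask targets. A minimal usage sketch; the `MultitaskGPModel` class and all shapes here are hypothetical, assuming a standard two-task exact GP with a `MultitaskGaussianLikelihood`:

import torch
import gpytorch

# Hypothetical two-task exact GP, following the standard gpytorch multitask setup.
class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=2)
        self.covar_module = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=2)

    def forward(self, x):
        # Returns a MultitaskMultivariateNormal, so prediction_strategy.train_prior_dist
        # is multitask and data_dim_start becomes -2.
        return gpytorch.distributions.MultitaskMultivariateNormal(self.mean_module(x), self.covar_module(x))

train_x, train_y = torch.randn(10, 1), torch.randn(10, 2)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = MultitaskGPModel(train_x, train_y, likelihood).eval()

# A posterior call is needed first so the prediction strategy (and its caches) exist.
with torch.no_grad():
    model(torch.randn(4, 1))

fant_x, fant_y = torch.randn(3, 1), torch.randn(3, 2)  # multitask targets: (m, t)
fant_model = model.get_fantasy_model(fant_x, fant_y)   # previously raised RuntimeError

Note that the new code consults `self.prediction_strategy.train_prior_dist`, so fantasizing without a prior posterior call is still an error, as before this commit.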

gpytorch/models/exact_prediction_strategies.py

Lines changed: 29 additions & 4 deletions
@@ -22,6 +22,8 @@
 from torch import Tensor

 from .. import settings
+
+from ..distributions import MultitaskMultivariateNormal
 from ..lazy import LazyEvaluatedKernelTensor
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache

@@ -134,16 +136,27 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
             been added and all test-time caches have been updated.
         """
+        if not isinstance(full_output, MultitaskMultivariateNormal):
+            target_batch_shape = targets.shape[:-1]
+        else:
+            target_batch_shape = targets.shape[:-2]
+
         full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix

         batch_shape = full_inputs[0].shape[:-2]

-        full_mean = full_mean.view(*batch_shape, -1)
         num_train = self.num_train

+        if isinstance(full_output, MultitaskMultivariateNormal):
+            num_tasks = full_output.event_shape[-1]
+            full_mean = full_mean.view(*batch_shape, -1, num_tasks)
+            fant_mean = full_mean[..., (num_train // num_tasks) :, :]
+        else:
+            full_mean = full_mean.view(*batch_shape, -1)
+            fant_mean = full_mean[..., num_train:]
+
         # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
         fant_fant_covar = full_covar[..., num_train:, num_train:]
-        fant_mean = full_mean[..., num_train:]
         mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
         fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
         mvn_obs = fant_likelihood(mvn, inputs, **kwargs)
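
The `num_train // num_tasks` arithmetic reflects that for a `MultitaskMultivariateNormal`, `num_train` counts covariance rows (points times tasks) while the mean keeps a `(points, tasks)` shape. An illustrative sketch, with made-up numbers:

import torch

n, m, t = 10, 3, 2                  # train points, fantasy points, tasks
num_train = n * t                   # self.num_train counts covariance rows: 20
full_mean = torch.randn(n + m, t)   # multitask mean keeps a (points, tasks) shape

# Convert covariance rows back into data points before slicing off the fantasies.
fant_mean = full_mean[(num_train // t):, :]
assert fant_mean.shape == torch.Size([m, t])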
@@ -198,6 +211,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
         new_root = new_lt.root_decomposition().root.to_dense()
         new_covar_cache = new_lt.root_inv_decomposition().root.to_dense()

+        full_targets = full_targets.view(*target_batch_shape, -1)
+
         # Expand inputs accordingly if necessary (for fantasies at the same points)
         if full_inputs[0].dim() <= full_targets.dim():
             fant_batch_shape = full_targets.shape[:1]
@@ -209,6 +224,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
             # no need to repeat the covar cache, broadcasting will do the right thing

+        if isinstance(full_output, MultitaskMultivariateNormal):
+            full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()
+
         # Create new DefaultPredictionStrategy object
         fant_strat = self.__class__(
             train_inputs=full_inputs,
@@ -258,7 +276,11 @@ def train_shape(self):

     def exact_prediction(self, joint_mean, joint_covar):
         # Find the components of the distribution that contain test data
-        test_mean = joint_mean[..., self.num_train :]
+        if not isinstance(self.train_prior_dist, MultitaskMultivariateNormal):
+            test_mean = joint_mean[..., self.num_train :]
+        else:
+            num_tasks = joint_mean.shape[-1]
+            test_mean = joint_mean[..., (self.num_train // num_tasks) :, :]
         # For efficiency - we can make things more efficient
         if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
             test_covar = joint_covar[..., self.num_train :, :].to_dense()
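
Same bookkeeping as in `get_fantasy_strategy`: the multitask `joint_mean` is sliced in units of points, while `joint_covar` is still sliced in units of rows. Illustrative shapes, assuming the same conventions as above:

import torch

n, m, t = 10, 3, 2
num_train = n * t
joint_mean = torch.randn(n + m, t)                    # sliced in units of points
joint_covar = torch.randn((n + m) * t, (n + m) * t)   # sliced in units of rows

test_mean = joint_mean[(num_train // t):, :]   # (m, t)
test_covar = joint_covar[num_train:, :]        # (m * t, (n + m) * t)
assert test_mean.shape == torch.Size([m, t])
assert test_covar.shape == torch.Size([m * t, (n + m) * t])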
@@ -285,7 +307,10 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
         # NOTE TO FUTURE SELF:
         # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
         # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
-        res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+        if not isinstance(self.train_prior_dist, MultitaskMultivariateNormal):
+            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+        else:
+            res = (test_train_covar.unsqueeze(1) @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
         res = res + test_mean

         return res
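
The `unsqueeze(1)` in the multitask branch inserts a singleton batch dimension so the covariance can broadcast against an extra batch dimension in the mean cache. One shape combination this broadcasting supports (illustrative only; the actual shapes depend on the fantasy batch setup):

import torch

# A fantasy-batched covariance shared across an extra model batch dim.
f, b, mt, nt = 4, 3, 6, 20                 # fantasy batch, model batch, m*t, n*t
test_train_covar = torch.randn(f, mt, nt)
mean_cache = torch.randn(f, b, nt)

res = (test_train_covar.unsqueeze(1) @ mean_cache.unsqueeze(-1)).squeeze(-1)
assert res.shape == torch.Size([f, b, mt])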
