Commit b9dc064

Passing unit tests now
1 parent 969a9ec commit b9dc064

2 files changed: +7, -11 lines

gpytorch/models/exact_gp.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def __call__(self, *args, **kwargs):
             if settings.debug().on():
                 if not isinstance(full_output, MultivariateNormal):
                     raise RuntimeError("ExactGP.forward must return a MultivariateNormal")
-            full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
+            full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix
 
             # Determine the shape of the joint distribution
             batch_shape = full_output.batch_shape
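A likely motivation for switching from .mean to .loc (my reading; the commit message does not say): for a MultitaskMultivariateNormal, the mean property is reshaped to (..., num_points, num_tasks), whereas loc keeps the flat (..., num_points * num_tasks) layout that the joint-shape bookkeeping below works with. A minimal sketch of that shape difference, using made-up sizes:

import torch
from gpytorch.distributions import MultitaskMultivariateNormal

# Toy sizes: 5 data points, 2 tasks (numbers are illustrative only).
mean = torch.zeros(5, 2)   # task-shaped mean, (num_points, num_tasks)
covar = torch.eye(5 * 2)   # dense (num_points * num_tasks) x (num_points * num_tasks) covariance
mtmvn = MultitaskMultivariateNormal(mean, covar)

print(mtmvn.mean.shape)    # expected: torch.Size([5, 2]), reshaped per task
print(mtmvn.loc.shape)     # expected: torch.Size([10]), flat location vector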

gpytorch/models/exact_prediction_strategies.py

Lines changed: 6 additions & 10 deletions
@@ -151,6 +151,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             num_tasks = full_output.event_shape[-1]
             full_mean = full_mean.view(*batch_shape, -1, num_tasks)
             fant_mean = full_mean[..., (num_train // num_tasks) :, :]
+            full_targets = full_targets.view(*target_batch_shape, -1)
         else:
             full_mean = full_mean.view(*batch_shape, -1)
             fant_mean = full_mean[..., num_train:]
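The added line flattens the multitask targets in the same place where the means are reshaped (the same statement is removed from its later position in the next hunk). A small, self-contained sketch of the shape bookkeeping involved, with invented sizes and under the assumption that num_train counts flattened (point x task) entries:

import torch

# Invented sizes: batch of 3, 4 training points, 2 fantasy points, 2 tasks.
batch_shape = torch.Size([3])
target_batch_shape = batch_shape
num_tasks = 2
num_train = 4 * num_tasks                                    # flattened training count (assumption)

full_mean = torch.randn(*batch_shape, (4 + 2) * num_tasks)   # (3, 12), flat layout
full_mean = full_mean.view(*batch_shape, -1, num_tasks)      # (3, 6, 2), one row per data point
fant_mean = full_mean[..., (num_train // num_tasks):, :]     # (3, 2, 2), fantasy rows only

full_targets = torch.randn(*batch_shape, 4 + 2, num_tasks)   # (3, 6, 2), task-shaped targets
full_targets = full_targets.view(*target_batch_shape, -1)    # (3, 12), flattened like full_mean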
@@ -211,8 +212,6 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
         new_root = new_lt.root_decomposition().root.to_dense()
         new_covar_cache = new_lt.root_inv_decomposition().root.to_dense()
 
-        full_targets = full_targets.view(*target_batch_shape, -1)
-
         # Expand inputs accordingly if necessary (for fantasies at the same points)
         if full_inputs[0].dim() <= full_targets.dim():
             fant_batch_shape = full_targets.shape[:1]
@@ -276,11 +275,7 @@ def train_shape(self):
 
     def exact_prediction(self, joint_mean, joint_covar):
         # Find the components of the distribution that contain test data
-        if not isinstance(self.train_prior_dist, MultitaskMultivariateNormal):
-            test_mean = joint_mean[..., self.num_train :]
-        else:
-            num_tasks = joint_mean.shape[-1]
-            test_mean = joint_mean[..., (self.num_train // num_tasks) :, :]
+        test_mean = joint_mean[..., self.num_train :]
         # For efficiency - we can make things more efficient
         if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
             test_covar = joint_covar[..., self.num_train :, :].to_dense()
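With the mean kept in its flat layout (see the .loc change above), a single slice now covers both the single-task and multitask cases. A toy illustration, assuming self.num_train counts flattened entries in the multitask case:

import torch

# Invented layout: 4 training points, 2 test points, 2 tasks, flat joint mean.
num_train = 4 * 2                        # flattened training count (assumption)
joint_mean = torch.arange(12.0)          # (4 + 2) * 2 entries in total

test_mean = joint_mean[..., num_train:]  # the trailing test entries
print(test_mean)                         # tensor([ 8.,  9., 10., 11.])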
@@ -307,10 +302,11 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
         # NOTE TO FUTURE SELF:
         # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
         # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
-        if not isinstance(self.train_prior_dist, MultitaskMultivariateNormal):
-            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+
+        if len(self.mean_cache.shape) == 4:
+            res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
         else:
-            res = (test_train_covar.unsqueeze(1) @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
         res = res + test_mean
 
         return res
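The new branch keys off the number of dimensions of the cached solve rather than the distribution type. Below is pure shape arithmetic for the 4-d case with invented dimension meanings (they are not documented in the commit), just to show what the squeeze(1)/unsqueeze(-1) pair does to the broadcast matmul:

import torch

# Invented shapes for illustration only.
mean_cache = torch.randn(3, 1, 2, 5)         # e.g. (batch, singleton, tasks?, num_train)
test_train_covar = torch.randn(3, 2, 4, 5)   # e.g. (batch, tasks?, num_test, num_train)

res = (test_train_covar @ mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
print(res.shape)                             # torch.Size([3, 2, 4]): one value per test point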
