@@ -22,6 +22,8 @@
 from torch import Tensor
 
 from .. import settings
+
+from ..distributions import MultitaskMultivariateNormal
 from ..lazy import LazyEvaluatedKernelTensor
 from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache
 
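For orientation (this sketch is not part of the diff): `MultitaskMultivariateNormal` is GPyTorch's distribution over `n` points and `t` tasks, with event shape `(n, t)` and a task-interleaved flattened representation by default; the hunks below use `event_shape[-1]` to recover the task count. A minimal construction with made-up sizes:

```python
import torch
from gpytorch.distributions import MultitaskMultivariateNormal

mean = torch.zeros(4, 2)   # n = 4 points, t = 2 tasks
covar = torch.eye(4 * 2)   # (n * t) x (n * t), task-interleaved by default
mtmvn = MultitaskMultivariateNormal(mean, covar)

print(mtmvn.event_shape)      # torch.Size([4, 2])
print(mtmvn.event_shape[-1])  # 2 -- the num_tasks lookup used below
```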
@@ -134,16 +136,28 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
             been added and all test-time caches have been updated.
         """
+        if not isinstance(full_output, MultitaskMultivariateNormal):
+            target_batch_shape = targets.shape[:-1]
+        else:
+            target_batch_shape = targets.shape[:-2]
+
         full_mean, full_covar = full_output.mean, full_output.lazy_covariance_matrix
 
         batch_shape = full_inputs[0].shape[:-2]
 
-        full_mean = full_mean.view(*batch_shape, -1)
         num_train = self.num_train
 
+        if isinstance(full_output, MultitaskMultivariateNormal):
+            num_tasks = full_output.event_shape[-1]
+            full_mean = full_mean.view(*batch_shape, -1, num_tasks)
+            fant_mean = full_mean[..., (num_train // num_tasks) :, :]
+            full_targets = full_targets.view(*target_batch_shape, -1)
+        else:
+            full_mean = full_mean.view(*batch_shape, -1)
+            fant_mean = full_mean[..., num_train:]
+
         # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
         fant_fant_covar = full_covar[..., num_train:, num_train:]
-        fant_mean = full_mean[..., num_train:]
         mvn = self.train_prior_dist.__class__(fant_mean, fant_fant_covar)
         fant_likelihood = self.likelihood.get_fantasy_likelihood(**kwargs)
         mvn_obs = fant_likelihood(mvn, inputs, **kwargs)
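The shape bookkeeping in this hunk is easier to see on a toy tensor. In the interleaved layout the flattened mean has length `n * t` and `num_train` counts flattened entries, so after viewing the mean as `(..., n, t)` the fantasy rows start at `num_train // num_tasks`. A standalone sketch with made-up sizes (not GPyTorch internals):

```python
import torch

num_tasks = 2
n_train, n_fant = 3, 2
num_train = n_train * num_tasks   # flattened count, as used by the covariance slice

# Flattened multitask mean: tasks interleaved, length (n_train + n_fant) * num_tasks
full_mean = torch.arange((n_train + n_fant) * num_tasks, dtype=torch.float)

# View as (n, t): one row per input point, one column per task
full_mean = full_mean.view(-1, num_tasks)

# Training points occupy the first num_train // num_tasks rows; the rest are fantasy points
fant_mean = full_mean[(num_train // num_tasks):, :]
print(fant_mean.shape)   # torch.Size([2, 2])
```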
@@ -209,6 +223,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
             new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape)
             # no need to repeat the covar cache, broadcasting will do the right thing
 
+        if isinstance(full_output, MultitaskMultivariateNormal):
+            full_mean = full_mean.view(*target_batch_shape, -1, num_tasks).contiguous()
+
         # Create new DefaultPredictionStrategy object
         fant_strat = self.__class__(
             train_inputs=full_inputs,
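Taken together, the two hunks above let the standard fantasy-model entry point work when the model's posterior is a `MultitaskMultivariateNormal`. A rough usage sketch, assuming a made-up two-task model and random placeholder data (the model class and sizes here are illustrative, not from the diff):

```python
import torch
import gpytorch


class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, num_tasks=2):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=num_tasks)
        self.covar_module = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=num_tasks)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


train_x = torch.randn(10, 1)
train_y = torch.randn(10, 2)   # two tasks
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = MultitaskGPModel(train_x, train_y, likelihood)

model.eval()
with torch.no_grad():
    model(torch.randn(4, 1))   # prime the prediction strategy and its caches
    # Condition on 3 new observations across both tasks
    fant_model = model.get_fantasy_model(torch.randn(3, 1), torch.randn(3, 2))
```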
@@ -285,7 +302,11 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
         # NOTE TO FUTURE SELF:
         # You **cannot** use addmv here, because test_train_covar may not actually be a non-lazy tensor even for an exact
         # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
-        res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
+
+        if len(self.mean_cache.shape) == 4:
+            res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
+        else:
+            res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
         res = res + test_mean
 
         return res
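The new 4-D branch only changes where a singleton batch dimension sits before the matmul. A toy shape check (all sizes, and the meaning of each dimension, are illustrative assumptions rather than the real cache layout):

```python
import torch

mean_cache = torch.randn(5, 1, 3, 4)         # hypothetical 4-D cache with a singleton dim 1
test_train_covar = torch.randn(5, 3, 2, 4)   # hypothetical (batch, batch, n_test, n_train)

res = (test_train_covar @ mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1)
# mean_cache.squeeze(1)  -> (5, 3, 4)
# .unsqueeze(-1)         -> (5, 3, 4, 1)
# matmul                 -> (5, 3, 2, 4) @ (5, 3, 4, 1) = (5, 3, 2, 1)
# .squeeze(-1)           -> (5, 3, 2)
print(res.shape)   # torch.Size([5, 3, 2])
```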