2222from torch import Tensor
2323
2424from .. import settings
25+
26+ from ..distributions import MultitaskMultivariateNormal
2527from ..lazy import LazyEvaluatedKernelTensor
2628from ..utils .memoize import add_to_cache , cached , clear_cache_hook , pop_from_cache
2729
@@ -134,16 +136,27 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
134136 A `DefaultPredictionStrategy` model with `n + m` training examples, where the `m` fantasy examples have
135137 been added and all test-time caches have been updated.
136138 """
139+ if not isinstance (full_output , MultitaskMultivariateNormal ):
140+ target_batch_shape = targets .shape [:- 1 ]
141+ else :
142+ target_batch_shape = targets .shape [:- 2 ]
143+
137144 full_mean , full_covar = full_output .mean , full_output .lazy_covariance_matrix
138145
139146 batch_shape = full_inputs [0 ].shape [:- 2 ]
140147
141- full_mean = full_mean .view (* batch_shape , - 1 )
142148 num_train = self .num_train
143149
150+ if isinstance (full_output , MultitaskMultivariateNormal ):
151+ num_tasks = full_output .event_shape [- 1 ]
152+ full_mean = full_mean .view (* batch_shape , - 1 , num_tasks )
153+ fant_mean = full_mean [..., (num_train // num_tasks ) :, :]
154+ else :
155+ full_mean = full_mean .view (* batch_shape , - 1 )
156+ fant_mean = full_mean [..., num_train :]
157+
144158 # Evaluate fant x train and fant x fant covariance matrices, leave train x train unevaluated.
145159 fant_fant_covar = full_covar [..., num_train :, num_train :]
146- fant_mean = full_mean [..., num_train :]
147160 mvn = self .train_prior_dist .__class__ (fant_mean , fant_fant_covar )
148161 fant_likelihood = self .likelihood .get_fantasy_likelihood (** kwargs )
149162 mvn_obs = fant_likelihood (mvn , inputs , ** kwargs )
@@ -198,6 +211,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
198211 new_root = new_lt .root_decomposition ().root .to_dense ()
199212 new_covar_cache = new_lt .root_inv_decomposition ().root .to_dense ()
200213
214+ full_targets = full_targets .view (* target_batch_shape , - 1 )
215+
201216 # Expand inputs accordingly if necessary (for fantasies at the same points)
202217 if full_inputs [0 ].dim () <= full_targets .dim ():
203218 fant_batch_shape = full_targets .shape [:1 ]
@@ -209,6 +224,9 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
209224 new_root = BatchRepeatLinearOperator (DenseLinearOperator (new_root ), repeat_shape )
210225 # no need to repeat the covar cache, broadcasting will do the right thing
211226
227+ if isinstance (full_output , MultitaskMultivariateNormal ):
228+ full_mean = full_mean .view (* target_batch_shape , - 1 , num_tasks ).contiguous ()
229+
212230 # Create new DefaultPredictionStrategy object
213231 fant_strat = self .__class__ (
214232 train_inputs = full_inputs ,
@@ -258,7 +276,11 @@ def train_shape(self):
258276
259277 def exact_prediction (self , joint_mean , joint_covar ):
260278 # Find the components of the distribution that contain test data
261- test_mean = joint_mean [..., self .num_train :]
279+ if not isinstance (self .train_prior_dist , MultitaskMultivariateNormal ):
280+ test_mean = joint_mean [..., self .num_train :]
281+ else :
282+ num_tasks = joint_mean .shape [- 1 ]
283+ test_mean = joint_mean [..., (self .num_train // num_tasks ) :, :]
262284 # For efficiency - we can make things more efficient
263285 if joint_covar .size (- 1 ) <= settings .max_eager_kernel_size .value ():
264286 test_covar = joint_covar [..., self .num_train :, :].to_dense ()
@@ -285,7 +307,10 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
285307 # NOTE TO FUTURE SELF:
286308 # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
287309 # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
288- res = (test_train_covar @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
310+ if not isinstance (self .train_prior_dist , MultitaskMultivariateNormal ):
311+ res = (test_train_covar @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
312+ else :
313+ res = (test_train_covar .unsqueeze (1 ) @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
289314 res = res + test_mean
290315
291316 return res
0 commit comments