@@ -151,6 +151,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
151151 num_tasks = full_output .event_shape [- 1 ]
152152 full_mean = full_mean .view (* batch_shape , - 1 , num_tasks )
153153 fant_mean = full_mean [..., (num_train // num_tasks ) :, :]
154+ full_targets = full_targets .view (* target_batch_shape , - 1 )
154155 else :
155156 full_mean = full_mean .view (* batch_shape , - 1 )
156157 fant_mean = full_mean [..., num_train :]
@@ -211,8 +212,6 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
211212 new_root = new_lt .root_decomposition ().root .to_dense ()
212213 new_covar_cache = new_lt .root_inv_decomposition ().root .to_dense ()
213214
214- full_targets = full_targets .view (* target_batch_shape , - 1 )
215-
216215 # Expand inputs accordingly if necessary (for fantasies at the same points)
217216 if full_inputs [0 ].dim () <= full_targets .dim ():
218217 fant_batch_shape = full_targets .shape [:1 ]
@@ -276,11 +275,7 @@ def train_shape(self):
276275
277276 def exact_prediction (self , joint_mean , joint_covar ):
278277 # Find the components of the distribution that contain test data
279- if not isinstance (self .train_prior_dist , MultitaskMultivariateNormal ):
280- test_mean = joint_mean [..., self .num_train :]
281- else :
282- num_tasks = joint_mean .shape [- 1 ]
283- test_mean = joint_mean [..., (self .num_train // num_tasks ) :, :]
278+ test_mean = joint_mean [..., self .num_train :]
284279 # For efficiency - we can make things more efficient
285280 if joint_covar .size (- 1 ) <= settings .max_eager_kernel_size .value ():
286281 test_covar = joint_covar [..., self .num_train :, :].to_dense ()
@@ -307,10 +302,11 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera
307302 # NOTE TO FUTURE SELF:
308303 # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact
309304 # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no!
310- if not isinstance (self .train_prior_dist , MultitaskMultivariateNormal ):
311- res = (test_train_covar @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
305+
306+ if len (self .mean_cache .shape ) == 4 :
307+ res = (test_train_covar @ self .mean_cache .squeeze (1 ).unsqueeze (- 1 )).squeeze (- 1 )
312308 else :
313- res = (test_train_covar . unsqueeze ( 1 ) @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
309+ res = (test_train_covar @ self .mean_cache .unsqueeze (- 1 )).squeeze (- 1 )
314310 res = res + test_mean
315311
316312 return res
0 commit comments