1010 AddedDiagLinearOperator ,
1111 BatchRepeatLinearOperator ,
1212 ConstantMulLinearOperator ,
13- DenseLinearOperator ,
1413 InterpolatedLinearOperator ,
1514 LinearOperator ,
1615 LowRankRootAddedDiagLinearOperator ,
@@ -211,8 +210,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
211210
212211 # now update the root and root inverse
213212 new_lt = self .lik_train_train_covar .cat_rows (fant_train_covar , fant_fant_covar )
214- new_root = new_lt .root_decomposition ().root . to_dense ()
215- new_covar_cache = new_lt .root_inv_decomposition ().root . to_dense ()
213+ new_root = new_lt .root_decomposition ().root
214+ new_covar_cache = new_lt .root_inv_decomposition ().root
216215
217216 # Expand inputs accordingly if necessary (for fantasies at the same points)
218217 if full_inputs [0 ].dim () <= full_targets .dim ():
@@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
222221 full_inputs = [fi .expand (fant_batch_shape + fi .shape ) for fi in full_inputs ]
223222 full_mean = full_mean .expand (fant_batch_shape + full_mean .shape )
224223 full_covar = BatchRepeatLinearOperator (full_covar , repeat_shape )
225- new_root = BatchRepeatLinearOperator (DenseLinearOperator ( new_root ) , repeat_shape )
224+ new_root = BatchRepeatLinearOperator (new_root , repeat_shape )
226225 # no need to repeat the covar cache, broadcasting will do the right thing
227226
228227 if isinstance (full_output , MultitaskMultivariateNormal ):
@@ -238,7 +237,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_
238237 inv_root = new_covar_cache ,
239238 )
240239 add_to_cache (fant_strat , "mean_cache" , fant_mean_cache )
241- add_to_cache (fant_strat , "covar_cache" , new_covar_cache )
240+ add_to_cache (fant_strat , "covar_cache" , new_covar_cache . to_dense () )
242241 return fant_strat
243242
244243 @property
0 commit comments