@@ -21,36 +21,57 @@ def forward(self, x):
2121
2222
2323class TestLeaveOneOutPseudoLikelihood (unittest .TestCase ):
24- def get_data (self , shapes , dtype = None , device = None ):
24+ def get_data (self , shapes , combine_terms , dtype = None , device = None ):
2525 train_x = torch .rand (* shapes , dtype = dtype , device = device , requires_grad = True )
2626 train_y = torch .sin (train_x [..., 0 ]) + torch .cos (train_x [..., 1 ])
2727 likelihood = gpytorch .likelihoods .GaussianLikelihood ().to (dtype = dtype , device = device )
2828 model = ExactGPModel (train_x , train_y , likelihood ).to (dtype = dtype , device = device )
29- loocv = gpytorch .mlls .LeaveOneOutPseudoLikelihood (likelihood = likelihood , model = model )
29+ loocv = gpytorch .mlls .LeaveOneOutPseudoLikelihood (
30+ likelihood = likelihood ,
31+ model = model ,
32+ combine_terms = combine_terms
33+ )
3034 return train_x , train_y , loocv
3135
3236 def test_smoke (self ):
3337 """Make sure the loocv works without batching."""
34- train_x , train_y , loocv = self .get_data ([5 , 2 ])
38+ train_x , train_y , loocv = self .get_data ([5 , 2 ], combine_terms = True )
3539 output = loocv .model (train_x )
3640 loss = - loocv (output , train_y )
3741 loss .backward ()
3842 self .assertTrue (train_x .grad is not None )
3943
44+ train_x , train_y , loocv = self .get_data ([5 , 2 ], combine_terms = False )
45+ output = loocv .model (train_x )
46+ mll_out = loocv (output , train_y )
47+ loss = - 1 * sum (mll_out )
48+ loss .backward ()
49+ assert len (mll_out ) == 4
50+ self .assertTrue (train_x .grad is not None )
51+
4052 def test_smoke_batch (self ):
4153 """Make sure the loocv works without batching."""
42- train_x , train_y , loocv = self .get_data ([3 , 3 , 3 , 5 , 2 ])
54+ train_x , train_y , loocv = self .get_data ([3 , 3 , 3 , 5 , 2 ], combine_terms = True )
4355 output = loocv .model (train_x )
4456 loss = - loocv (output , train_y )
4557 assert loss .shape == (3 , 3 , 3 )
4658 loss .sum ().backward ()
4759 self .assertTrue (train_x .grad is not None )
4860
61+ train_x , train_y , loocv = self .get_data ([3 , 3 , 3 , 5 , 2 ], combine_terms = False )
62+ output = loocv .model (train_x )
63+ mll_out = loocv (output , train_y )
64+ loss = - 1 * sum (mll_out )
65+ assert len (mll_out ) == 4
66+ assert loss .shape == (3 , 3 , 3 )
67+ loss .sum ().backward ()
68+ self .assertTrue (train_x .grad is not None )
69+
4970 def test_check_bordered_system (self ):
5071 """Make sure that the bordered system solves match the naive solution."""
5172 n = 5
5273 # Compute the pseudo-likelihood via the bordered systems in O(n^3)
53- train_x , train_y , loocv = self .get_data ([n , 2 ], dtype = torch .float64 )
74+ train_x , train_y , loocv = self .get_data ([n , 2 ], combine_terms = True , dtype = torch .float64 )
5475 output = loocv .model (train_x )
5576 loocv_1 = loocv (output , train_y )
5677
0 commit comments