@@ -3,6 +3,7 @@
 import os
 import random
 import unittest
+from unittest.mock import MagicMock, patch
 import warnings
 from math import exp, pi
 
@@ -14,6 +15,7 @@
 from gpytorch.means import ConstantMean
 from gpytorch.priors import SmoothedBoxPrior
 from gpytorch.test.utils import least_used_cuda_device
+from gpytorch.utils.cholesky import CHOLESKY_METHOD
 from gpytorch.utils.warnings import NumericalWarning
 from torch import optim
 
@@ -65,41 +67,56 @@ def tearDown(self):
         if hasattr(self, "rng_state"):
             torch.set_rng_state(self.rng_state)
 
-    def test_sgpr_mean_abs_error(self):
+    def test_sgpr_mean_abs_error(self, cuda=False):
         # Suppress numerical warnings
         warnings.simplefilter("ignore", NumericalWarning)
 
-        train_x, train_y, test_x, test_y = make_data()
+        train_x, train_y, test_x, test_y = make_data(cuda=cuda)
         likelihood = GaussianLikelihood()
         gp_model = GPRegressionModel(train_x, train_y, likelihood)
         mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)
 
-        # Optimize the model
-        gp_model.train()
-        likelihood.train()
+        if cuda:
+            gp_model = gp_model.cuda()
+            likelihood = likelihood.cuda()
 
-        optimizer = optim.Adam(gp_model.parameters(), lr=0.1)
-        for _ in range(30):
-            optimizer.zero_grad()
-            output = gp_model(train_x)
-            loss = -mll(output, train_y)
-            loss.backward()
-            optimizer.step()
+        # Mock cholesky
+        _wrapped_cholesky = MagicMock(
+            wraps=torch.linalg.cholesky if CHOLESKY_METHOD == "torch.linalg.cholesky" else torch.linalg.cholesky_ex
+        )
+        with patch(CHOLESKY_METHOD, new=_wrapped_cholesky) as cholesky_mock:
 
-        # Check that we have the right LazyTensor type
-        kernel = likelihood(gp_model(train_x)).lazy_covariance_matrix.evaluate_kernel()
-        self.assertIsInstance(kernel, gpytorch.lazy.LowRankRootAddedDiagLazyTensor)
+            # Optimize the model
+            gp_model.train()
+            likelihood.train()
 
-        for param in gp_model.parameters():
-            self.assertTrue(param.grad is not None)
-            self.assertGreater(param.grad.norm().item(), 0)
+            optimizer = optim.Adam(gp_model.parameters(), lr=0.1)
+            for _ in range(30):
+                optimizer.zero_grad()
+                output = gp_model(train_x)
+                loss = -mll(output, train_y)
+                loss.backward()
+                optimizer.step()
 
-        # Test the model
-        gp_model.eval()
-        likelihood.eval()
+            # Check that we have the right LazyTensor type
+            kernel = likelihood(gp_model(train_x)).lazy_covariance_matrix.evaluate_kernel()
+            self.assertIsInstance(kernel, gpytorch.lazy.LowRankRootAddedDiagLazyTensor)
 
-        test_preds = likelihood(gp_model(test_x)).mean
-        mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
+            for param in gp_model.parameters():
+                self.assertTrue(param.grad is not None)
+                self.assertGreater(param.grad.norm().item(), 0)
+
+            # Test the model
+            gp_model.eval()
+            likelihood.eval()
+
+            test_preds = likelihood(gp_model(test_x)).mean
+            mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
+            cholesky_mock.assert_called()  # We SHOULD call Cholesky...
+            for chol_arg in cholesky_mock.call_args_list:
+                first_arg = chol_arg[0][0]
+                self.assertTrue(torch.is_tensor(first_arg))
+                self.assertTrue(first_arg.size(-1) == gp_model.covar_module.inducing_points.size(-2))
 
         self.assertLess(mean_abs_error.squeeze().item(), 0.1)
 
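Note: the size assertion added above encodes the point of SGPR's low-rank structure: with m inducing points, every Cholesky factorization taken during inference should be of an m x m matrix, never n x n. Below is a minimal pure-torch sketch of why an m x m Cholesky suffices for a diagonal-plus-low-rank system, via the Woodbury identity; the names n, m, U, d are illustrative and not part of this PR.

import torch

torch.manual_seed(0)
n, m = 100, 10
U = torch.randn(n, m)      # low-rank root (n x m), analogous to the SGPR kernel root
d = torch.rand(n) + 0.5    # positive diagonal, analogous to the noise term
b = torch.randn(n)

# Direct solve: requires factorizing the full n x n matrix.
A = torch.diag(d) + U @ U.T
x_direct = torch.linalg.solve(A, b)

# Woodbury solve: the only Cholesky is of the m x m "capacitance" matrix.
Dinv_b = b / d
Dinv_U = U / d.unsqueeze(-1)
C = torch.eye(m) + U.T @ Dinv_U               # m x m
L = torch.linalg.cholesky(C)                  # the only Cholesky, size m x m
w = torch.cholesky_solve((U.T @ Dinv_b).unsqueeze(-1), L).squeeze(-1)
x_woodbury = Dinv_b - Dinv_U @ w

assert torch.allclose(x_direct, x_woodbury, atol=1e-5)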
@@ -123,62 +140,9 @@ def test_sgpr_mean_abs_error_cuda(self):
 
         if not torch.cuda.is_available():
             return
-        with least_used_cuda_device():
-            train_x, train_y, test_x, test_y = make_data(cuda=True)
-            likelihood = GaussianLikelihood().cuda()
-            gp_model = GPRegressionModel(train_x, train_y, likelihood).cuda()
-            mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp_model)
-
-            # Test the model before optimization
-            gp_model.eval()
-            likelihood.eval()
-            test_preds = likelihood(gp_model(test_x)).mean
-            mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
-            self.assertLess(mean_abs_error.squeeze().item(), 0.02)
-
-            # Test variances before optimization
-            test_vars = likelihood(gp_model(test_x)).variance
-            self.assertAllClose(test_vars, likelihood(gp_model(test_x)).covariance_matrix.diagonal(dim1=-1, dim2=-2))
-            self.assertGreater(test_vars.min().item() + 0.1, likelihood.noise.item())
-            self.assertLess(
-                test_vars.max().item() - 0.05,
-                likelihood.noise.item() + gp_model.covar_module.base_kernel.outputscale.item()
-            )
-
-            # Optimize the model
-            gp_model.train()
-            likelihood.train()
-
-            optimizer = optim.Adam(gp_model.parameters(), lr=0.1)
-            optimizer.n_iter = 0
-            for _ in range(25):
-                optimizer.zero_grad()
-                output = gp_model(train_x)
-                loss = -mll(output, train_y)
-                loss.backward()
-                optimizer.n_iter += 1
-                optimizer.step()
 
-            for param in gp_model.parameters():
-                self.assertTrue(param.grad is not None)
-                self.assertGreater(param.grad.norm().item(), 0)
-
-            # Test the model
-            gp_model.eval()
-            likelihood.eval()
-            test_preds = likelihood(gp_model(test_x)).mean
-            mean_abs_error = torch.mean(torch.abs(test_y - test_preds))
-
-            self.assertLess(mean_abs_error.squeeze().item(), 0.02)
-
-            # Test variances
-            test_vars = likelihood(gp_model(test_x)).variance
-            self.assertAllClose(test_vars, likelihood(gp_model(test_x)).covariance_matrix.diagonal(dim1=-1, dim2=-2))
-            self.assertGreater(test_vars.min().item() + 0.1, likelihood.noise.item())
-            self.assertLess(
-                test_vars.max().item() - 0.05,
-                likelihood.noise.item() + gp_model.covar_module.base_kernel.outputscale.item()
-            )
+        with least_used_cuda_device():
+            self.test_sgpr_mean_abs_error(cuda=True)
 
 
 if __name__ == "__main__":
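Note: the new test relies on the unittest.mock "spy" idiom: MagicMock(wraps=fn) forwards every call to the real fn while recording its arguments, and patch(target, new=...) installs the spy at the patch target (here the string constant CHOLESKY_METHOD). A minimal self-contained sketch of that idiom, spying on math.sqrt as a stand-in target that is not part of this PR:

import math
import unittest
from unittest.mock import MagicMock, patch


class TestSpyPattern(unittest.TestCase):
    def test_sqrt_is_called_with_expected_argument(self):
        # Wrap the real function so behavior is unchanged but calls are recorded.
        wrapped = MagicMock(wraps=math.sqrt)
        with patch("math.sqrt", new=wrapped) as sqrt_mock:
            result = math.sqrt(4.0)  # dispatches through the spy to the real sqrt
        sqrt_mock.assert_called()  # the spy recorded the call...
        self.assertEqual(sqrt_mock.call_args_list[0][0][0], 4.0)  # ...and its arguments
        self.assertEqual(result, 2.0)  # numerical behavior is unchanged


if __name__ == "__main__":
    unittest.main()

Because the mock wraps the real routine, the SGPR test above exercises the genuine Cholesky path; the assertions only inspect the recorded call_args_list to verify that every factorized matrix has the inducing-point size.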