2222NOISE = [0.127 , - 0.113 , - 0.345 , - 0.034 , - 0.069 , - 0.272 , 0.013 , 0.056 , 0.087 , - 0.081 ]
2323
2424MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
25+ MAX_RETRY_MSG = "Fitting failed on all retries."
2526
2627
2728class TestFitGPyTorchModel (unittest .TestCase ):
@@ -40,10 +41,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
4041 for double in (False , True ):
4142 mll = self ._getModel (double = double , cuda = cuda )
4243 with warnings .catch_warnings (record = True ) as ws :
43- mll = fit_gpytorch_model (mll , optimizer = optimizer , options = options )
44+ mll = fit_gpytorch_model (
45+ mll , optimizer = optimizer , options = options , max_retries = 1
46+ )
4447 if optimizer == fit_gpytorch_scipy :
4548 self .assertEqual (len (ws ), 1 )
46- self .assertTrue (MAX_ITER_MSG in str (ws [- 1 ].message ))
49+ self .assertTrue (MAX_RETRY_MSG in str (ws [- 1 ].message ))
4750 model = mll .model
4851 # Make sure all of the parameters changed
4952 self .assertGreater (model .likelihood .raw_noise .abs ().item (), 1e-3 )
@@ -60,11 +63,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
6063 mll ,
6164 optimizer = optimizer ,
6265 options = options ,
66+ max_retries = 1 ,
6367 bounds = {"likelihood.noise_covar.raw_noise" : (1e-1 , None )},
6468 )
6569 if optimizer == fit_gpytorch_scipy :
6670 self .assertEqual (len (ws ), 1 )
67- self .assertTrue (MAX_ITER_MSG in str (ws [- 1 ].message ))
71+ self .assertTrue (MAX_RETRY_MSG in str (ws [- 1 ].message ))
6872
6973 model = mll .model
7074 self .assertGreaterEqual (model .likelihood .raw_noise .abs ().item (), 1e-1 )
@@ -100,10 +104,12 @@ def test_fit_gpytorch_model(self, cuda=False, optimizer=fit_gpytorch_scipy):
100104 ),
101105 )
102106 with warnings .catch_warnings (record = True ) as ws :
103- mll = fit_gpytorch_model (mll , optimizer = optimizer , options = options )
107+ mll = fit_gpytorch_model (
108+ mll , optimizer = optimizer , options = options , max_retries = 1
109+ )
104110 if optimizer == fit_gpytorch_scipy :
105111 self .assertEqual (len (ws ), 1 )
106- self .assertTrue (MAX_ITER_MSG in str (ws [- 1 ].message ))
112+ self .assertTrue (MAX_RETRY_MSG in str (ws [- 1 ].message ))
107113 self .assertTrue (mll .dummy_param .grad is None )
108114
109115 def test_fit_gpytorch_model_cuda (self ):
@@ -123,9 +129,10 @@ def test_fit_gpytorch_model_singular(self, cuda=False):
123129 mll = ExactMarginalLogLikelihood (gp .likelihood , gp )
124130 mll .to (device = device , dtype = dtype )
125131 with warnings .catch_warnings (record = True ) as ws :
132+ # this will do multiple retries
126133 fit_gpytorch_model (mll , options = options )
127134 self .assertEqual (len (ws ), 1 )
128- self .assertTrue ("Fitting failed" in str (ws [0 ].message ))
135+ self .assertTrue (MAX_RETRY_MSG in str (ws [0 ].message ))
129136
130137 def test_fit_gpytorch_model_singular_cuda (self ):
131138 if torch .cuda .is_available ():
0 commit comments