3737from botorch .posteriors .gpytorch import GPyTorchPosterior
3838from botorch .posteriors .posterior import Posterior
3939from gpytorch import settings
40- from gpytorch .constraints import GreaterThan
40+ from gpytorch .constraints import GreaterThan , Interval
4141from gpytorch .distributions .multivariate_normal import MultivariateNormal
4242from gpytorch .kernels .rbf_kernel import RBFKernel
4343from gpytorch .kernels .scale_kernel import ScaleKernel
@@ -147,7 +147,7 @@ def __init__(
147147
148148 # Set optional parameters
149149 # Explicitly set jitter for numerical stability in psd_safe_cholesky
150- self ._jitter = kwargs .get ("jitter" , 1e-5 )
150+ self ._jitter = kwargs .get ("jitter" , 1e-6 )
151151 # Stopping criterion in scipy.optimize.fsolve used to find f_map in _update()
152152 # If None, set to 1e-6 by default in _update
153153 self ._xtol = kwargs .get ("xtol" )
@@ -170,6 +170,7 @@ def __init__(
170170 # estimates away from scale value that would make Phi(f(x)) saturate
171171 # at 0 or 1
172172 if covar_module is None :
173+ os_lb , os_ub = 1e-2 , 1e2
173174 ls_prior = GammaPrior (1.2 , 0.5 )
174175 ls_prior_mode = (ls_prior .concentration - 1 ) / ls_prior .rate
175176 covar_module = ScaleKernel (
@@ -181,9 +182,16 @@ def __init__(
181182 lower_bound = 1e-4 , transform = None , initial_value = ls_prior_mode
182183 ),
183184 ),
184- outputscale_prior = SmoothedBoxPrior (a = 1 , b = 4 ),
185+ outputscale_prior = SmoothedBoxPrior (a = os_lb , b = os_ub ),
186+ # make sure we won't get extreme values for the output scale
187+ outputscale_constraint = Interval (
188+ lower_bound = os_lb * 0.5 ,
189+ upper_bound = os_ub * 2.0 ,
190+ initial_value = 1.0 ,
191+ ),
185192 )
186-
193+ if not isinstance (covar_module , ScaleKernel ):
194+ raise UnsupportedError ("PairwiseGP must be used with a ScaleKernel." )
187195 self .covar_module = covar_module
188196
189197 self ._x0 = None # will store temporary results for warm-starting
@@ -225,6 +233,16 @@ def __deepcopy__(self, memo) -> PairwiseGP:
225233 self .__deepcopy__ = dcp
226234 return new_model
227235
236+ def _scaled_psd_safe_cholesky (
237+ self , M : Tensor , jitter : Optional [float ] = None
238+ ) -> Tensor :
239+ r"""Divide M by the kernel outputscale before psd_safe_cholesky for better numerical stability, then multiply the factor by sqrt(outputscale) and return it."""
240+ scale = self .covar_module .outputscale .unsqueeze (- 1 ).unsqueeze (- 1 )
241+ M = M / scale
242+ chol = psd_safe_cholesky (M , jitter = jitter )
243+ chol = chol * scale .sqrt ()
244+ return chol
245+
228246 def _has_no_data (self ):
229247 r"""Return true if the model does not have both datapoints and comparisons"""
230248 return (
@@ -238,24 +256,6 @@ def _calc_covar(self, X1: Tensor, X2: Tensor) -> Union[Tensor, LinearOperator]:
238256 covar = self .covar_module (X1 , X2 )
239257 return covar .to_dense ()
240258
241- def _batch_chol_inv (self , mat_chol : Tensor ) -> Tensor :
242- r"""Wrapper to perform (batched) cholesky inverse"""
243- # TODO: get rid of this once cholesky_inverse supports batch mode
244- batch_eye = torch .eye (
245- mat_chol .shape [- 1 ],
246- dtype = self .datapoints .dtype ,
247- device = self .datapoints .device ,
248- )
249-
250- if len (mat_chol .shape ) == 2 :
251- mat_inv = torch .cholesky_inverse (mat_chol )
252- elif len (mat_chol .shape ) > 2 and (mat_chol .shape [- 1 ] == mat_chol .shape [- 2 ]):
253- batch_eye = batch_eye .repeat (* (mat_chol .shape [:- 2 ]), 1 , 1 )
254- chol_inv = torch .linalg .solve_triangular (mat_chol , batch_eye , upper = False )
255- mat_inv = chol_inv .transpose (- 1 , - 2 ) @ chol_inv
256-
257- return mat_inv
258-
259259 def _update_covar (self , datapoints : Tensor ) -> None :
260260 r"""Update values derived from the data and hyperparameters
261261
@@ -265,8 +265,10 @@ def _update_covar(self, datapoints: Tensor) -> None:
265265 datapoints: (Transformed) datapoints for finding f_max
266266 """
267267 self .covar = self ._calc_covar (datapoints , datapoints )
268- self .covar_chol = psd_safe_cholesky (self .covar , jitter = self ._jitter )
269- self .covar_inv = self ._batch_chol_inv (self .covar_chol )
268+ self .covar_chol = self ._scaled_psd_safe_cholesky (
269+ self .covar , jitter = self ._jitter
270+ )
271+ self .covar_inv = torch .cholesky_inverse (self .covar_chol )
270272
271273 def _prior_mean (self , X : Tensor ) -> Union [Tensor , LinearOperator ]:
272274 r"""Return point prediction using prior only
@@ -417,7 +419,17 @@ def _update(self, datapoints: Tensor, **kwargs) -> None:
417419 # warm start
418420 init_x0_size = self .batch_shape + torch .Size ([self .n ])
419421 if self ._x0 is None or torch .Size (self ._x0 .shape ) != init_x0_size :
420- x0 = np .random .rand (* init_x0_size )
422+ sqrt_scale = (
423+ self .covar_module .outputscale .sqrt ()
424+ .unsqueeze (- 1 )
425+ .detach ()
426+ .cpu ()
427+ .numpy ()
428+ )
429+ # initialize x0 using std normal but clip by 3 std to keep it bounded
430+ x0 = np .random .standard_normal (init_x0_size ).clip (min = - 3 , max = 3 )
431+ # scale x0 to be on roughly the right scale
432+ x0 = x0 * sqrt_scale
421433 else :
422434 x0 = self ._x0
423435
@@ -755,7 +767,6 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
755767 2. Prior predictions (prior mode)
756768 3. Predictive posterior (eval mode)
757769 """
758-
759770 # Training mode: optimizing
760771 if self .training :
761772 if self ._has_no_data ():
@@ -839,7 +850,7 @@ def forward(self, datapoints: Tensor) -> MultivariateNormal:
839850 # output_covar is sometimes non-PSD
840851 # perform a cholesky decomposition to check and amend
841852 covariance_matrix = RootLinearOperator (
842- psd_safe_cholesky (output_covar , jitter = self ._jitter )
853+ self . _scaled_psd_safe_cholesky (output_covar , jitter = self ._jitter )
843854 ),
844855 )
845856 return post
0 commit comments