     fit_gpytorch_mll,
     get_fitted_map_saas_ensemble,
     get_fitted_map_saas_model,
+    logger,
 )
-from botorch.models import SaasFullyBayesianSingleTaskGP, SingleTaskGP
+from botorch.models import SingleTaskGP
 from botorch.models.map_saas import (
     add_saas_prior,
     AdditiveMapSaasSingleTaskGP,
@@ -291,93 +292,24 @@ def test_get_saas_model(self) -> None:
         self.assertTrue(loss < loss_short)
 
     def test_get_saas_ensemble(self) -> None:
-        for infer_noise, taus in itertools.product([True, False], [None, [0.1, 0.2]]):
-            tkwargs = {"device": self.device, "dtype": torch.double}
-            train_X, train_Y, _ = self._get_data_hardcoded(**tkwargs)
-            d = train_X.shape[-1]
-            train_Yvar = (
-                None
-                if infer_noise
-                else 0.1 * torch.arange(len(train_X), **tkwargs).unsqueeze(-1)
-            )
-            # Fit without specifying tau
-            with torch.random.fork_rng():
-                torch.manual_seed(0)
-                model = get_fitted_map_saas_ensemble(
-                    train_X=train_X,
-                    train_Y=train_Y,
-                    train_Yvar=train_Yvar,
-                    input_transform=Normalize(d=d),
-                    outcome_transform=Standardize(m=1),
-                    taus=taus,
-                )
-            self.assertIsInstance(model, SaasFullyBayesianSingleTaskGP)
-            num_taus = 4 if taus is None else len(taus)
-            self.assertEqual(
-                model.covar_module.base_kernel.lengthscale.shape,
-                torch.Size([num_taus, 1, d]),
-            )
-            self.assertEqual(model.batch_shape, torch.Size([num_taus]))
-            # Make sure the lengthscales are reasonable
-            self.assertGreater(
-                model.covar_module.base_kernel.lengthscale[..., 1:].min(), 50
-            )
-            self.assertLess(
-                model.covar_module.base_kernel.lengthscale[..., 0].max(), 10
-            )
-
-            # testing optimizer_options: short optimization run with maxiter = 3
-            with torch.random.fork_rng():
-                torch.manual_seed(0)
-                fit_gpytorch_mll_mock = mock.Mock(wraps=fit_gpytorch_mll)
-                with mock.patch(
-                    "botorch.fit.fit_gpytorch_mll",
-                    new=fit_gpytorch_mll_mock,
-                ):
-                    maxiter = 3
-                    model_short = get_fitted_map_saas_ensemble(
-                        train_X=train_X,
-                        train_Y=train_Y,
-                        train_Yvar=train_Yvar,
-                        input_transform=Normalize(d=d),
-                        outcome_transform=Standardize(m=1),
-                        taus=taus,
-                        optimizer_kwargs={"options": {"maxiter": maxiter}},
-                    )
-                    kwargs = fit_gpytorch_mll_mock.call_args.kwargs
-                    # fit_gpytorch_mll has "options" kwarg, not "optimizer_options"
-                    self.assertEqual(
-                        kwargs["optimizer_kwargs"]["options"]["maxiter"], maxiter
-                    )
-
-            # compute sum of marginal likelihoods of ensemble after short run
-            # NOTE: We can't put MLL in train mode here since
-            # SaasFullyBayesianSingleTaskGP requires NUTS for training.
-            mll_short = ExactMarginalLogLikelihood(
-                model=model_short, likelihood=model_short.likelihood
+        train_X, train_Y, _ = self._get_data_hardcoded(device=self.device)
+        with self.assertLogs(logger=logger, level="WARNING") as logs, mock.patch(
+            "botorch.fit.fit_gpytorch_mll"
+        ) as mock_fit:
+            model = get_fitted_map_saas_ensemble(
+                train_X=train_X,
+                train_Y=train_Y,
+                input_transform=Normalize(d=train_X.shape[-1]),
+                outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
+                optimizer_kwargs={"options": {"maxiter": 3}},
             )
-            train_inputs = mll_short.model.train_inputs
-            train_targets = mll_short.model.train_targets
-            loss_short = -mll_short(model_short(*train_inputs), train_targets)
-            # compute sum of marginal likelihoods of ensemble after standard run
-            mll = ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
-            # reusing train_inputs and train_targets, since the transforms are the same
-            loss = -mll(model(*train_inputs), train_targets)
-            # the longer running optimization should have smaller loss than the shorter
-            self.assertLess((loss - loss_short).max(), 0.0)
-
-            # test error message
-            with self.assertRaisesRegex(
-                ValueError, "if you only specify one value of tau"
-            ):
-                model_short = get_fitted_map_saas_ensemble(
-                    train_X=train_X,
-                    train_Y=train_Y,
-                    train_Yvar=train_Yvar,
-                    input_transform=Normalize(d=d),
-                    outcome_transform=Standardize(m=1),
-                    taus=[0.1],
-                )
+        self.assertTrue(
+            any("use EnsembleMapSaasGP instead" in output for output in logs.output)
+        )
+        self.assertEqual(
+            mock_fit.call_args.kwargs["optimizer_kwargs"], {"options": {"maxiter": 3}}
+        )
+        self.assertIsInstance(model, EnsembleMapSaasGP)
 
     def test_input_transform_in_train(self) -> None:
         train_X, train_Y, test_X = self._get_data()
@@ -514,7 +446,7 @@ def test_batch_model_fitting(self) -> None:
 
     @mock_optimize
     def test_emsemble_map_saas(self) -> None:
-        train_X, train_Y, test_X = self._get_data()
+        train_X, train_Y, test_X = self._get_data(device=self.device)
         d = train_X.shape[-1]
         num_taus = 8
         for with_options in (False, True):
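For reference, the rewritten test asserts that get_fitted_map_saas_ensemble now logs a warning recommending EnsembleMapSaasGP and returns an instance of that class (previously it returned a SaasFullyBayesianSingleTaskGP built from MAP sub-models). A minimal call sketch under that assumption; the training data, optimizer budget, and the final posterior call are illustrative and not taken from the test, while the batch_shape of 4 matches the default number of taus seen in the removed code:

import torch
from botorch.fit import get_fitted_map_saas_ensemble
from botorch.models.transforms import Normalize, Standardize

# Illustrative data: 20 points in 3 dimensions (not from the test).
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = train_X.sin().sum(dim=-1, keepdim=True)

# Per this change, the helper emits a WARNING pointing to EnsembleMapSaasGP
# and returns an EnsembleMapSaasGP fitted via fit_gpytorch_mll; the
# batch_shape of 4 matches the default number of taus in the ensemble.
model = get_fitted_map_saas_ensemble(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=train_X.shape[-1]),
    outcome_transform=Standardize(m=1, batch_shape=torch.Size([4])),
    optimizer_kwargs={"options": {"maxiter": 100}},  # illustrative budget
)

# Standard BoTorch posterior query on held-out points.
posterior = model.posterior(torch.rand(5, 3, dtype=torch.double))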