@@ -3,15 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

+from typing import Any

import torch
+from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
    get_gaussian_likelihood_with_lognormal_prior,
)
+from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from botorch.utils.constraints import LogTransformedInterval
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.constraints import Interval
@@ -30,12 +33,15 @@
class SaasPriorHelper:
    """Helper class for specifying parameter and setting closures."""

-    def __init__(self, tau: float | None = None):
+    def __init__(self, tau: Tensor | float | None = None):
        """Instantiates a new helper object.

        Args:
            tau: Value of the global shrinkage parameter. If `None`, the tau will be
                a free parameter and inferred from the data.
+                Tau can be a tensor for batched models, like `EnsembleMapSaasGP`,
+                where each batch has a different sparsity prior. If tau is a tensor,
+                it must have shape `batch_shape`.
        """
        self._tau = torch.as_tensor(tau) if tau is not None else None

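The tensor-valued `tau` is the enabling change for the ensemble model added below. As a minimal sketch of what it permits, here is a per-batch tau attached to a batched kernel; the import path `botorch.models.map_saas` and the toy dimensions are assumptions for illustration:

```python
import torch
from botorch.models.map_saas import add_saas_prior
from gpytorch.kernels import MaternKernel

# A batch of 4 kernels (one per ensemble member) over a 10-dimensional input.
base_kernel = MaternKernel(nu=2.5, ard_num_dims=10, batch_shape=torch.Size([4]))
# One fixed tau per batch member: smaller values enforce stronger sparsity.
taus = torch.tensor([0.01, 0.1, 1.0, 10.0])
add_saas_prior(base_kernel=base_kernel, tau=taus)
```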
@@ -102,10 +108,8 @@ def tau_prior_setting_closure(self, m: Kernel, value: Tensor) -> None:
        """
        lb = m.raw_tau_constraint.lower_bound.to(m.raw_tau)
        ub = m.raw_tau_constraint.upper_bound.to(m.raw_tau)
-        m.raw_tau.data.fill_(
-            m.raw_tau_constraint.inverse_transform(
-                value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
-            ).item()
+        m.raw_tau.data = m.raw_tau_constraint.inverse_transform(
+            value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
        )


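This closure rewrite is what lets the setting closure accept batched taus: `Tensor.item()` only converts single-element tensors, so the old in-place fill could not handle a tensor-valued tau. A standalone illustration (not part of this diff):

```python
import torch

scalar_tau = torch.tensor([0.5])
print(scalar_tau.item())  # fine: exactly one element

batched_tau = torch.tensor([0.5, 2.0])  # one tau per ensemble member
# batched_tau.item() raises RuntimeError ("a Tensor with 2 elements cannot be
# converted to Scalar"), hence the switch to assigning the transformed tensor
# to `raw_tau.data` directly.
```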
@@ -218,7 +222,7 @@ def get_map_saas_model(
    )
    # NOTE: need to call `to` to set device and dtype before calling `add_saas_prior`,
    # since the SAAS prior contains tensors that are not parameters of the model, and
-    # terefore not automatically moved to the correct device with a `to` call on the
+    # therefore not automatically moved to the correct device with a `to` call on the
    # model.
    base_kernel.to(train_X)
    add_saas_prior(base_kernel=base_kernel, tau=tau)
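For orientation, a sketch of typical usage of `get_map_saas_model`, assuming it keeps the `train_X`/`train_Y`/`tau` parameters this hunk relies on; the toy data is an assumption:

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.map_saas import get_map_saas_model
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 10, dtype=torch.float64)
train_Y = train_X[:, :2].sum(dim=-1, keepdim=True)  # only 2 of 10 dims matter

# Fix the global shrinkage tau; pass tau=None to infer it from the data instead.
model = get_map_saas_model(train_X=train_X, train_Y=train_Y, tau=0.1)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
```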
@@ -421,3 +425,139 @@ def __init__(
        )
        # Make sure that all buffers and parameters have the correct device and dtype
        self.to(dtype=train_X.dtype, device=train_X.device)
+
+
+class EnsembleMapSaasGP(SingleTaskGP):
+    _is_ensemble = True
+
+    def __init__(
+        self,
+        train_X: Tensor,
+        train_Y: Tensor,
+        train_Yvar: Tensor | None = None,
+        num_taus: int = 4,
+        taus: Tensor | None = None,
+        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
+        input_transform: InputTransform | None = None,
+    ) -> None:
+        """Instantiates an ``EnsembleMapSaasGP``, which is a batched ensemble of
+        ``SingleTaskGP``s with the Matern-5/2 kernel and a SAAS prior. The model is
+        intended to be trained with ``ExactMarginalLogLikelihood`` and
+        ``fit_gpytorch_mll``. Under the hood, the model is equivalent to a
+        multi-output ``BatchedMultiOutputGPyTorchModel``, but it produces a
+        ``GaussianMixturePosterior``, which leads to ensembling of the model outputs.
+
+        Args:
+            train_X: An `n x d` tensor of training features.
+            train_Y: An `n x 1` tensor of training observations.
+            train_Yvar: An optional `n x 1` tensor of observed measurement noise.
+            num_taus: The number of taus to use (4 if omitted). Each tau is
+                a sparsity parameter for the corresponding kernel in the ensemble.
+            taus: An optional tensor of shape `num_taus` containing the taus to use.
+                If omitted, the taus are sampled from a HalfCauchy(0.1) distribution.
+            outcome_transform: An outcome transform that is applied to the
+                training data during instantiation and to the posterior during
+                inference (that is, the `Posterior` obtained by calling
+                `.posterior` on the model will be on the original scale). We use a
+                `Standardize` transform if no `outcome_transform` is specified.
+                Pass down `None` to use no outcome transform. Note that `.train()` will
+                be called on the outcome transform during instantiation of the model.
+            input_transform: An input transform that is applied in the model's
+                forward pass.
+        """
+        if taus is None:
+            taus = HalfCauchy(torch.tensor(0.1)).sample([num_taus]).to(train_X)
+        elif taus.shape != torch.Size([num_taus]):
+            raise ValueError(
+                f"Expected taus to be of shape {[num_taus]}. Got {taus.shape=}."
+            )
+        if train_Y.shape[-1] != 1:
+            raise UnsupportedError(
+                f"EnsembleMapSaasGP only supports single-output. Got {train_Y.shape=}."
+            )
+        if train_X.ndim != 2:
+            raise UnsupportedError(
+                f"EnsembleMapSaasGP only supports 2D inputs. Got {train_X.ndim=}."
+            )
+        # Add batch dimension for ensemble.
+        train_X = train_X.repeat(num_taus, 1, 1)
+        train_Y = train_Y.repeat(num_taus, 1, 1)
+        if train_Yvar is not None:
+            train_Yvar = train_Yvar.repeat(num_taus, 1, 1)
+        # Construct the sub-modules.
+        if input_transform is not None:
+            with torch.no_grad():
+                transformed_X = input_transform(train_X)
+            ard_num_dims = transformed_X.shape[-1]
+        else:
+            ard_num_dims = train_X.shape[-1]
+        batch_shape = train_X.shape[:-2]  # This is torch.Size([num_taus]).
+        mean_module = get_mean_module_with_normal_prior(batch_shape=batch_shape)
+        base_kernel = MaternKernel(
+            nu=2.5, ard_num_dims=ard_num_dims, batch_shape=batch_shape
+        )
+        # NOTE: need to call `to` to set device and dtype before calling
+        # `add_saas_prior`, since the SAAS prior contains tensors that are not
+        # parameters of the model, and therefore not automatically moved to the
+        # correct device with a `to` call on the model.
+        base_kernel.to(train_X)
+        add_saas_prior(base_kernel=base_kernel, tau=taus)
+        covar_module = ScaleKernel(
+            base_kernel=base_kernel,
+            outputscale_constraint=LogTransformedInterval(1e-2, 1e4, initial_value=10),
+            batch_shape=batch_shape,
+        )
+        if train_Yvar is None:
+            likelihood = get_gaussian_likelihood_with_gamma_prior(
+                batch_shape=batch_shape
+            )
+        else:
+            likelihood = None
+
+        super().__init__(
+            train_X=train_X,
+            train_Y=train_Y,
+            train_Yvar=train_Yvar,
+            likelihood=likelihood,
+            covar_module=covar_module,
+            mean_module=mean_module,
+            outcome_transform=outcome_transform,
+            input_transform=input_transform,
+        )
+
+    def posterior(
+        self,
+        X: Tensor,
+        output_indices: list[int] | None = None,
+        observation_noise: bool = False,
+        posterior_transform: PosteriorTransform | None = None,
+        **kwargs: Any,
+    ) -> GaussianMixturePosterior:
+        r"""Computes the posterior over model outputs at the provided points.
+
+        Args:
+            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
+                of the feature space and `q` is the number of points considered
+                jointly.
+            output_indices: A list of indices, corresponding to the outputs over
+                which to compute the posterior (if the model is multi-output).
+                Can be used to speed up computation if only a subset of the
+                model's outputs are required for optimization. If omitted,
+                computes the posterior over all model outputs.
+            observation_noise: If True, add the observation noise from the
+                likelihood to the posterior. If a Tensor, use it directly as the
+                observation noise (must be of shape `(batch_shape) x q x m`).
+            posterior_transform: An optional PosteriorTransform.
+
+        Returns:
+            A `GaussianMixturePosterior` object. Includes observation noise
+            if specified.
+        """
+        posterior = super().posterior(
+            X=X.unsqueeze(MCMC_DIM),
+            output_indices=output_indices,
+            observation_noise=observation_noise,
+            posterior_transform=posterior_transform,
+            **kwargs,
+        )
+        return GaussianMixturePosterior(distribution=posterior.distribution)
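Since the new class covers construction, fitting, and prediction, an end-to-end sketch may help; the toy data, the printed shapes, and the `botorch.models.map_saas` import path are assumptions for illustration:

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.map_saas import EnsembleMapSaasGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 10, dtype=torch.float64)
train_Y = train_X[:, :2].sum(dim=-1, keepdim=True)

# Four ensemble members, each with its own tau sampled from HalfCauchy(0.1).
model = EnsembleMapSaasGP(train_X=train_X, train_Y=train_Y, num_taus=4)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

test_X = torch.rand(5, 10, dtype=torch.float64)
posterior = model.posterior(test_X)
print(posterior.mean.shape)          # torch.Size([4, 5, 1]): per-member means
print(posterior.mixture_mean.shape)  # torch.Size([5, 1]): averaged over taus
```

The `X.unsqueeze(MCMC_DIM)` in `posterior` lets the tau batch play the role that MCMC samples play in a fully Bayesian model, which is why the ensemble members come back as components of a `GaussianMixturePosterior`.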