diff --git a/botorch_community/acquisition/__init__.py b/botorch_community/acquisition/__init__.py
index 8f78e713e6..79eea60cf9 100644
--- a/botorch_community/acquisition/__init__.py
+++ b/botorch_community/acquisition/__init__.py
@@ -6,6 +6,7 @@
 from botorch_community.acquisition.bayesian_active_learning import (
     qBayesianQueryByComittee,
     qBayesianVarianceReduction,
+    qHyperparameterInformedPredictiveExploration,
     qStatisticalDistanceActiveLearning,
 )
 
@@ -23,6 +24,7 @@
     "LogRegionalExpectedImprovement",
     "qBayesianQueryByComittee",
     "qBayesianVarianceReduction",
+    "qHyperparameterInformedPredictiveExploration",
     "qLogRegionalExpectedImprovement",
     "qSelfCorrectingBayesianOptimization",
     "qStatisticalDistanceActiveLearning",
diff --git a/botorch_community/acquisition/bayesian_active_learning.py b/botorch_community/acquisition/bayesian_active_learning.py
index 14df4c7193..3409f64e19 100644
--- a/botorch_community/acquisition/bayesian_active_learning.py
+++ b/botorch_community/acquisition/bayesian_active_learning.py
@@ -33,13 +33,21 @@
 
 from __future__ import annotations
 
+import math
+
 from typing import Optional
 
 import torch
+from botorch.acquisition.acquisition import MCSamplerMixin
 from botorch.acquisition.bayesian_active_learning import (
     FullyBayesianAcquisitionFunction,
+    qBayesianActiveLearningByDisagreement,
 )
+from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
+from botorch.optim import optimize_acqf
+from botorch.sampling.base import MCSampler
+from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.transforms import (
     average_over_ensemble_models,
     concatenate_pending_points,
@@ -159,3 +167,190 @@ def forward(self, X: Tensor) -> Tensor:
         # squeeze output dim - batch dim computed and reduced inside of dist
         # MCMC dim is averaged in decorator
         return dist.squeeze(-1)
+
+
+class qExpectedPredictiveInformationGain(FullyBayesianAcquisitionFunction):
+    def __init__(
+        self,
+        model: SaasFullyBayesianSingleTaskGP,
+        mc_points: Tensor,
+        X_pending: Tensor | None = None,
+    ) -> None:
+        """Expected predictive information gain for active learning.
+
+        Computes the mutual information between candidate queries and a test set
+        (typically MC samples over the design space).
+
+        Args:
+            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
+            mc_points: A `N x d` tensor of points to use for MC-integrating the
+                posterior entropy (test set).
+            X_pending: A `m x d`-dim Tensor of `m` design points.
+        """
+        super().__init__(model)
+        if mc_points.ndim != 2:
+            raise ValueError(
+                f"mc_points must be a 2-dimensional tensor, but got shape "
+                f"{mc_points.shape}"
+            )
+        self.register_buffer("mc_points", mc_points)
+        self.set_X_pending(X_pending)
+
+    @concatenate_pending_points
+    @t_batch_mode_transform()
+    @average_over_ensemble_models
+    def forward(self, X: Tensor) -> Tensor:
+        """Evaluate the test-set information gain at the candidate points.
+
+        Args:
+            X: A `batch_shape x q x d`-dim Tensor of input points.
+
+        Returns:
+            A `batch_shape`-dim Tensor of information gain values.
+        """
+        # Get the posterior for the candidate points
+        posterior = self.model.posterior(X, observation_noise=True)
+        noise = (
+            posterior.variance
+            - self.model.posterior(X, observation_noise=False).variance
+        )
+        cond_Y = posterior.mean
+
+        # Condition the model on the candidate observations
+        cond_X = X.unsqueeze(-3).expand(*[cond_Y.shape[:-1] + X.shape[-1:]])
+        conditional_model = self.model.condition_on_observations(
+            X=cond_X,
+            Y=cond_Y,
+            noise=noise,
+        )
+
+        # Evaluate the posterior variance at the test set with and without
+        # conditioning on the candidates
+        uncond_var = self.model.posterior(
+            self.mc_points, observation_noise=True
+        ).variance
+        cond_var = conditional_model.posterior(
+            self.mc_points, observation_noise=True
+        ).variance
+
+        # Compute the information gain as the reduction in (Gaussian) entropy,
+        # averaged over the test set
+        prev_entropy = torch.log(uncond_var * 2 * math.pi * math.exp(1.0)).sum(-1) / 2
+        post_entropy = torch.log(cond_var * 2 * math.pi * math.exp(1.0)).sum(-1) / 2
+        return (prev_entropy - post_entropy).mean(-1)
+
+
+class qHyperparameterInformedPredictiveExploration(
+    FullyBayesianAcquisitionFunction, MCSamplerMixin
+):
+    def __init__(
+        self,
+        model: SaasFullyBayesianSingleTaskGP,
+        mc_points: Tensor,
+        bounds: Tensor,
+        q: int,
+        sampler: MCSampler | None = None,
+        posterior_transform: PosteriorTransform | None = None,
+        X_pending: Tensor | None = None,
+        num_samples: int = 512,
+        beta: float | None = None,
+        beta_tuning_method: str = "sobol",
+    ) -> None:
+        """Hyperparameter-Informed Predictive Exploration (HIPE) acquisition function.
+
+        This acquisition function combines the mutual information between candidate
+        queries and a test set (predictive information gain, TSIG) with the mutual
+        information between observations and hyperparameters (BALD), weighted by a
+        tuning factor. This balances exploration of the design space with reduction
+        of hyperparameter uncertainty.
+
+        The acquisition function is computed as
+            beta * BALD + TSIG,
+        where beta is either provided or automatically tuned.
+
+        Args:
+            model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
+            mc_points: A `N x d` tensor of points to use for MC-integrating the
+                posterior entropy (test set). Usually, these are qMC samples on
+                the whole design space.
+            bounds: A `2 x d` tensor of bounds for the design space, used for
+                beta tuning when beta is not provided.
+            q: Batch size to use for beta tuning.
+            sampler: The sampler used for drawing samples to approximate the
+                entropy of the Gaussian mixture posterior. If None, a default
+                sampler is used.
+            posterior_transform: An optional PosteriorTransform.
+            X_pending: A `m x d`-dim Tensor of `m` design points that have been
+                submitted for evaluation but have not yet been observed.
+            num_samples: Number of samples to use for MC estimation of entropy.
+            beta: Fixed tuning factor. If None, it is computed automatically.
+            beta_tuning_method: Method for tuning beta when it is not provided.
+                Options are "sobol" (set beta to the average TSIG/BALD ratio at
+                Sobol samples of the design space) or "optimize" (set beta to the
+                BALD value obtained by optimizing a BALD acquisition function).
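+
+        Example:
+            A minimal, illustrative sketch (assumes a fitted
+            `SaasFullyBayesianSingleTaskGP` named `model` and a `2 x d` tensor
+            `bounds`; the sample sizes, `q=2`, and optimizer settings are
+            arbitrary illustrative choices):
+
+            >>> mc_points = draw_sobol_samples(bounds=bounds, n=256, q=1).squeeze(-2)
+            >>> hipe = qHyperparameterInformedPredictiveExploration(
+            ...     model=model, mc_points=mc_points, bounds=bounds, q=2
+            ... )
+            >>> candidates, _ = optimize_acqf(
+            ...     hipe, bounds=bounds, q=2, num_restarts=8, raw_samples=128
+            ... )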
+        """
+        super().__init__(model=model)
+        MCSamplerMixin.__init__(self)
+        self.set_X_pending(X_pending)
+        self.q = q
+        self.num_samples = num_samples
+        self.beta_tuning_method = beta_tuning_method
+        self.register_buffer("mc_points", mc_points)
+        self.register_buffer("bounds", bounds)
+        self.sampler = sampler
+        self.posterior_transform = posterior_transform
+
+        if beta is None:
+            self.compute_tuning_factor(self.bounds, self.q)
+        else:
+            self.tuning_factor = beta
+
+    def compute_tuning_factor(self, bounds: Tensor, q: int) -> None:
+        """Compute the tuning factor beta that weights BALD against TSIG."""
+        if self.beta_tuning_method == "sobol":
+            draws = draw_sobol_samples(
+                bounds=bounds,
+                q=q,
+                n=1,
+            ).squeeze(0)
+            # Set beta to the average TSIG/BALD ratio at the Sobol samples
+            tsig_val = qExpectedPredictiveInformationGain.forward(
+                self,
+                draws,
+            )
+            bald_val = qBayesianActiveLearningByDisagreement.forward(self, draws)
+            self.tuning_factor = (tsig_val / (bald_val + 1e-8)).mean().item()
+        elif self.beta_tuning_method == "optimize":
+            # Set beta to the BALD value obtained by optimizing a standalone
+            # BALD acquisition function over the design space
+            bald_acqf = qBayesianActiveLearningByDisagreement(
+                model=self.model,
+                sampler=self.sampler,
+            )
+            _, bald_val = optimize_acqf(
+                bald_acqf,
+                bounds=bounds,
+                q=q,
+                num_restarts=1,
+                raw_samples=128,
+                options={"batch_limit": 16},
+            )
+            self.tuning_factor = bald_val.mean().item()
+        else:
+            raise ValueError(
+                f"beta_tuning_method must be 'sobol' or 'optimize', "
+                f"got {self.beta_tuning_method}"
+            )
+
+    @concatenate_pending_points
+    @t_batch_mode_transform()
+    def forward(self, X: Tensor) -> Tensor:
+        """Evaluate the acquisition function at X.
+
+        Args:
+            X: A `batch_shape x q x d`-dim Tensor of input points.
+
+        Returns:
+            A `batch_shape`-dim Tensor of acquisition values.
+        """
+        tsig = qExpectedPredictiveInformationGain.forward(self, X)
+        bald = qBayesianActiveLearningByDisagreement.forward(self, X)
+        # Since both acquisition functions are averaged over the ensemble,
+        # we do not average over the ensemble again here.
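+        # HIPE(X) = beta * BALD(X) + TSIG(X), with beta = self.tuning_factor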
+        return self.tuning_factor * bald + tsig
diff --git a/test_community/acquisition/test_bayesian_active_learning.py b/test_community/acquisition/test_bayesian_active_learning.py
index 8fed6d6a6f..7adeab06f1 100644
--- a/test_community/acquisition/test_bayesian_active_learning.py
+++ b/test_community/acquisition/test_bayesian_active_learning.py
@@ -7,11 +7,13 @@
 from itertools import product
 
 import torch
+from botorch.utils.sampling import draw_sobol_samples
 from botorch.utils.test_helpers import get_fully_bayesian_model
 from botorch.utils.testing import BotorchTestCase
 from botorch_community.acquisition.bayesian_active_learning import (
     qBayesianQueryByComittee,
     qBayesianVarianceReduction,
+    qHyperparameterInformedPredictiveExploration,
     qStatisticalDistanceActiveLearning,
 )
@@ -72,13 +74,59 @@ def test_q_statistical_distance_active_learning(self):
                 # assess shape
                 self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
-        with self.assertRaises(ValueError):
-            acq = qStatisticalDistanceActiveLearning(
+
+class TestQHyperparameterInformedPredictiveExploration(BotorchTestCase):
+    def test_q_hyperparameter_informed_predictive_exploration(self):
+        torch.manual_seed(1)
+        tkwargs = {"device": self.device}
+        num_objectives = 1
+        num_models = 3
+        for (
+            dtype,
+            standardize_model,
+            infer_noise,
+        ) in product(
+            (torch.float, torch.double),
+            (False, True),
+            (True,),
+        ):
+            tkwargs["dtype"] = dtype
+            input_dim = 2
+            train_X = torch.rand(4, input_dim, **tkwargs)
+            train_Y = torch.rand(4, num_objectives, **tkwargs)
+
+            model = get_fully_bayesian_model(
+                train_X=train_X,
+                train_Y=train_Y,
+                num_models=num_models,
+                standardize_model=standardize_model,
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+
+            bounds = torch.tensor([[0.0] * input_dim, [1.0] * input_dim], **tkwargs)
+            mc_points = draw_sobol_samples(bounds=bounds, n=16, q=1).squeeze(-2)
+
+            # test with fixed beta
+            acq = qHyperparameterInformedPredictiveExploration(
                 model=model,
-                distance_metric="NOT_A_DISTANCE",
-                X_pending=X_pending,
+                mc_points=mc_points,
+                bounds=bounds,
+                beta=1.0,
+                q=2,
             )
+            test_Xs = [
+                torch.rand(4, 1, input_dim, **tkwargs),
+                torch.rand(4, 3, input_dim, **tkwargs),
+            ]
+
+            for test_X in test_Xs:
+                acq_X = acq(test_X)
+                # assess shape
+                self.assertTrue(acq_X.shape == test_X.shape[:-2])
+                self.assertTrue((acq_X > 0).all())
+
 
 class TestQBayesianQueryByComittee(BotorchTestCase):
     def test_q_bayesian_query_by_comittee(self):
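
For reviewers, a minimal end-to-end usage sketch of the new acquisition function (not part of the patch). It relies only on the utilities the patch already imports (`draw_sobol_samples`, `optimize_acqf`) plus BoTorch's fully Bayesian model fitting; the synthetic objective, the `q=2` batch size, and the optimizer settings are illustrative assumptions, not defaults of this PR.

import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf
from botorch.utils.sampling import draw_sobol_samples
from botorch_community.acquisition import qHyperparameterInformedPredictiveExploration

# Toy setup: a 2d design space with a synthetic objective for illustration.
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
train_X = draw_sobol_samples(bounds=bounds, n=8, q=1).squeeze(-2)
train_Y = torch.sin(6.0 * train_X).sum(dim=-1, keepdim=True)

model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(model)  # fully Bayesian (NUTS) hyperparameter inference

# qMC test set over the design space for the predictive-information-gain term.
mc_points = draw_sobol_samples(bounds=bounds, n=256, q=1).squeeze(-2)

# beta is tuned automatically (beta_tuning_method="sobol") since it is not given.
acqf = qHyperparameterInformedPredictiveExploration(
    model=model,
    mc_points=mc_points,
    bounds=bounds,
    q=2,
)
candidates, _ = optimize_acqf(
    acqf, bounds=bounds, q=2, num_restarts=8, raw_samples=128
)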