
Commit 0cd1186

Manda Kausthubh committed
Added basic SAASBO
1 parent 2c78f7c commit 0cd1186

File tree

1 file changed: +234 −0 lines changed


bayes_opt/bayesian_optimization.py

Lines changed: 234 additions & 0 deletions
@@ -13,10 +13,14 @@
from typing import TYPE_CHECKING, Any
from warnings import warn

import torch
import pyro
import numpy as np
from scipy.optimize import NonlinearConstraint
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import Matern
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

from bayes_opt import acquisition
from bayes_opt.constraint import ConstraintModel
@@ -442,3 +446,233 @@ def load_state(self, path: str | PathLike[str]) -> None:
            state["random_state"]["cached_gaussian"],
        )
        self._random_state.set_state(random_state_tuple)


class SAASBO(BayesianOptimization):
    """Sparse Axis-Aligned Subspaces Bayesian Optimization (SAASBO).

    This class extends BayesianOptimization to implement SAASBO (Eriksson &
    Jankowiak, 2021), which places a sparsity-inducing horseshoe-type (SAAS)
    prior on the inverse squared length scales of the GP kernel, so that most
    dimensions of a high-dimensional problem are effectively ignored unless
    the data supports them. It uses MCMC for fully Bayesian inference over
    the GP hyperparameters.

    Additional Parameters
    ---------------------
    num_samples: int, optional (default=500)
        Number of MCMC samples to draw from the GP posterior.
    warmup_steps: int, optional (default=500)
        Number of warmup steps for MCMC sampling.
    thinning: int, optional (default=16)
        Thinning factor applied to the MCMC samples to reduce autocorrelation.
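
    Examples
    --------
    A minimal usage sketch; the toy objective and bounds below are
    illustrative, not part of the library:

    >>> def black_box(x, y):
    ...     return -x**2 - (y - 1.0) ** 2 + 1.0
    >>> optimizer = SAASBO(
    ...     f=black_box,
    ...     pbounds={"x": (-2.0, 2.0), "y": (-3.0, 3.0)},
    ...     num_samples=256,
    ...     warmup_steps=256,
    ...     thinning=16,
    ... )
    >>> optimizer.maximize(init_points=5, n_iter=10)  # doctest: +SKIP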
    """

    def __init__(
        self,
        f: Callable[..., float] | None,
        pbounds: Mapping[str, tuple[float, float]],
        acquisition_function: AcquisitionFunction | None = None,
        constraint: NonlinearConstraint | None = None,
        random_state: int | RandomState | None = None,
        verbose: int = 2,
        bounds_transformer: DomainTransformer | None = None,
        allow_duplicate_points: bool = False,
        num_samples: int = 500,
        warmup_steps: int = 500,
        thinning: int = 16,
    ):
        # Initialize the parent class
        super().__init__(
            f=f,
            pbounds=pbounds,
            acquisition_function=acquisition_function,
            constraint=constraint,
            random_state=random_state,
            verbose=verbose,
            bounds_transformer=bounds_transformer,
            allow_duplicate_points=allow_duplicate_points,
        )

        # SAASBO-specific parameters
        self.num_samples = num_samples
        self.warmup_steps = warmup_steps
        self.thinning = thinning
        self._random_state = ensure_rng(random_state)

        # Default to Expected Improvement when no acquisition function is given
        if acquisition_function is None:
            self._acquisition_function = acquisition.ExpectedImprovement(
                xi=0.01, random_state=self._random_state
            )

        # Drop the parent's sklearn GP regressor; the surrogate here is a
        # Pyro GP whose hyperparameters are sampled with MCMC
        self._gp = None
        self._mcmc_samples = None

    @staticmethod
    def _matern52(
        X: torch.Tensor, Z: torch.Tensor, lengthscale: torch.Tensor, variance: torch.Tensor
    ) -> torch.Tensor:
        """Matern 5/2 kernel matrix with per-dimension length scales."""
        x = X / lengthscale
        z = Z / lengthscale
        r2 = (x.unsqueeze(-2) - z.unsqueeze(-3)).pow(2).sum(-1).clamp_min(1e-12)
        r = r2.sqrt()
        sqrt5_r = 5.0**0.5 * r
        return variance * (1.0 + sqrt5_r + (5.0 / 3.0) * r2) * torch.exp(-sqrt5_r)

    def _define_gp_model(self, X: torch.Tensor, y: torch.Tensor) -> Callable:
        """Define the Pyro GP model with a SAAS (horseshoe-type) prior on the length scales.
        """

        def gp_model(X: torch.Tensor, y: torch.Tensor):
            loc0 = torch.tensor(0.0, dtype=X.dtype)
            scale1 = torch.tensor(1.0, dtype=X.dtype)

            # Kernel hyperparameters
            outputscale = pyro.sample("outputscale", dist.LogNormal(loc0, scale1))
            noise = pyro.sample("noise", dist.LogNormal(loc0 - 2.0, scale1))

            # SAAS prior: the half-Cauchy prior concentrates the inverse squared
            # length scales near zero, giving most dimensions long length scales
            # (effectively ignoring them) unless the data pulls them back.
            dim = X.shape[1]
            tau = pyro.sample("tau", dist.HalfCauchy(0.1 * scale1))
            inv_length_sq = pyro.sample(
                "inv_length_sq",
                dist.HalfCauchy(tau * torch.ones(dim, dtype=X.dtype)).to_event(1),
            )
            lengthscale = inv_length_sq.rsqrt()

            # Constant GP mean
            mean = pyro.sample("mean", dist.Normal(loc0, scale1))

            # Matern 5/2 kernel matrix, computed directly so the sampled
            # hyperparameters enter the likelihood (wrapping them in
            # pyro.contrib.gp modules would register them as fixed parameters
            # and break the MCMC gradients).
            n = X.shape[0]
            cov = self._matern52(X, X, lengthscale, outputscale)
            cov = cov + (noise + 1e-6) * torch.eye(n, dtype=X.dtype)

            # Observe the targets under the GP marginal likelihood
            pyro.sample(
                "y",
                dist.MultivariateNormal(
                    mean * torch.ones(n, dtype=X.dtype), covariance_matrix=cov
                ),
                obs=y,
            )

        return gp_model

    def _fit_gp(self) -> None:
        """Fit the GP model using MCMC to sample from the posterior.
        """
        if len(self._space) == 0:
            return

        # Convert the observed data to PyTorch tensors
        X = torch.tensor(self._space.params, dtype=torch.float64)
        y = torch.tensor(self._space.target, dtype=torch.float64)

        # Define the GP model
        gp_model = self._define_gp_model(X, y)

        # Set up MCMC with the No-U-Turn Sampler
        nuts_kernel = NUTS(gp_model)
        mcmc = MCMC(
            nuts_kernel,
            num_samples=self.num_samples,
            warmup_steps=self.warmup_steps,
            num_chains=1,
        )

        # Run MCMC and keep the posterior hyperparameter samples
        mcmc.run(X, y)
        self._mcmc_samples = mcmc.get_samples()

    def suggest(self) -> dict[str, float | np.ndarray]:
        """Suggest a promising point to probe next using SAASBO.

        The acquisition value of each candidate is averaged over thinned MCMC
        samples of the GP hyperparameters, which integrates out hyperparameter
        uncertainty instead of committing to a single point estimate.
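
        Notes
        -----
        With hyperparameter samples theta_1, ..., theta_M (after thinning) the
        integrated acquisition is approximated by the Monte Carlo average

            alpha(x) ~= (1 / M) * sum_m alpha(x; theta_m)

        With the defaults num_samples=500 and thinning=16, M = 32 models
        (sample indices 0, 16, ..., 496) are averaged.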
"""
578+
if len(self._space) == 0:
579+
return self._space.array_to_params(self._space.random_sample(random_state=self._random_state))
580+
581+
# Fit the GP model with MCMC if not already done
582+
if self._mcmc_samples is None:
583+
self._fit_gp()
584+
585+
# Generate candidate points (e.g., using random sampling or a grid)
586+
n_candidates = 1000
587+
candidates = self._space.random_sample(n_candidates, random_state=self._random_state)
588+
candidates_tensor = torch.tensor(candidates, dtype=torch.float64)
589+
590+
# Initialize acquisition values
591+
acq_values = torch.zeros(n_candidates, dtype=torch.float64)
592+
593+
# Average acquisition function over MCMC samples
594+
for i in range(0, self.num_samples, self.thinning):
595+
# Extract hyperparameters for this sample
596+
outputscale = self._mcmc_samples["outputscale"][i]
597+
noise = self._mcmc_samples["noise"][i]
598+
lengthscale = self._mcmc_samples["lengthscale"][i]
599+
mean = self._mcmc_samples["mean"][i]
600+
601+
# Define the GP model for this sample
602+
kernel = pyro.contrib.gp.kernels.Matern52(
603+
input_dim=candidates_tensor.shape[1],
604+
lengthscale=lengthscale,
605+
variance=outputscale,
606+
)
607+
gp = pyro.contrib.gp.models.GPRegression(
608+
X=torch.tensor(self._space.params, dtype=torch.float64),
609+
y=torch.tensor(self._space.target, dtype=torch.float64),
610+
kernel=kernel,
611+
noise=noise,
612+
mean=mean,
613+
)
614+
615+
# Compute acquisition function for candidates
616+
acq = self._acquisition_function.evaluate(
617+
candidates=candidates_tensor,
618+
gp=gp,
619+
target_space=self._space,
620+
)
621+
acq_values += acq / (self.num_samples // self.thinning)
622+
623+
# Select the candidate with the highest acquisition value
624+
best_idx = torch.argmax(acq_values)
625+
suggestion = candidates[best_idx]
626+
627+
return self._space.array_to_params(suggestion)

    def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
        """Maximize the target function using SAASBO.

        Parameters
        ----------
        init_points: int, optional (default=5)
            Number of random points to probe before starting the optimization.
        n_iter: int, optional (default=25)
            Number of iterations to perform.
        """
        self.logger.log_optimization_start(self._space.keys)
        self._prime_queue(init_points)

        iteration = 0
        while self._queue or iteration < n_iter:
            try:
                x_probe = self._queue.popleft()
            except IndexError:
                x_probe = self.suggest()
                iteration += 1
            self.probe(x_probe, lazy=False)

            # Refit the GP after each new observation
            self._fit_gp()

            if self._bounds_transformer and iteration > 0:
                self.set_bounds(self._bounds_transformer.transform(self._space))

        self.logger.log_optimization_end()

    def set_gp_params(self, **params: Any) -> None:
        """Set parameters for the SAASBO GP model.

        Parameters
        ----------
        num_samples: int, optional
            Number of MCMC samples.
        warmup_steps: int, optional
            Number of warmup steps for MCMC.
        thinning: int, optional
            Thinning factor for MCMC samples.
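
        Examples
        --------
        A sketch, assuming an already constructed SAASBO instance named
        optimizer:

        >>> optimizer.set_gp_params(num_samples=256, warmup_steps=128)  # doctest: +SKIP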
        """
        if "num_samples" in params:
            self.num_samples = params.pop("num_samples")
        if "warmup_steps" in params:
            self.warmup_steps = params.pop("warmup_steps")
        if "thinning" in params:
            self.thinning = params.pop("thinning")
        if params:
            # ScreenLogger has no warning() method; use the module-level warn
            warn(f"Ignored unknown parameters: {params}", stacklevel=2)
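
Since the MCMC samples are kept on the instance, a quick way to see which
dimensions the SAAS prior treats as relevant is to inspect the posterior
median length scales. A sketch, assuming the optimizer from the class
docstring example has already run maximize() (note that _mcmc_samples is the
private attribute added by this commit):

    inv_sq = optimizer._mcmc_samples["inv_length_sq"]  # shape (num_samples, dim)
    median_ls = inv_sq.rsqrt().median(dim=0).values
    # A small median length scale means the dimension is treated as relevant.
    print(dict(zip(optimizer.space.keys, median_ls.tolist())))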
