Skip to content

Commit 0b6148f

Browse files
authored
Merge branch 'bayesian-optimization:master' into master
2 parents 0e859df + 1714504 commit 0b6148f

27 files changed

+1686
-1092
lines changed

.github/workflows/build_docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ concurrency:
1313

1414
jobs:
1515
build-docs-and-publish:
16-
runs-on: ubuntu-20.04
16+
runs-on: ubuntu-latest
1717
permissions:
1818
contents: write
1919
steps:

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,8 @@ docsrc/*.ipynb
3636
docsrc/static/*
3737
docsrc/README.md
3838

39-
poetry.lock
39+
poetry.lock
40+
41+
# Add log files and optimizer state files to gitignore
42+
examples/logs.log
43+
examples/optimizer_state.json

bayes_opt/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import importlib.metadata
66

77
from bayes_opt import acquisition
8-
from bayes_opt.bayesian_optimization import BayesianOptimization, Events
8+
from bayes_opt.bayesian_optimization import BayesianOptimization
99
from bayes_opt.constraint import ConstraintModel
1010
from bayes_opt.domain_reduction import SequentialDomainReductionTransformer
11-
from bayes_opt.logger import JSONLogger, ScreenLogger
11+
from bayes_opt.logger import ScreenLogger
1212
from bayes_opt.target_space import TargetSpace
1313

1414
__version__ = importlib.metadata.version("bayesian-optimization")
@@ -19,8 +19,6 @@
1919
"BayesianOptimization",
2020
"TargetSpace",
2121
"ConstraintModel",
22-
"Events",
2322
"ScreenLogger",
24-
"JSONLogger",
2523
"SequentialDomainReductionTransformer",
2624
]

bayes_opt/acquisition.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,33 @@ def __init__(self, random_state: int | RandomState | None = None) -> None:
6969
self.random_state = RandomState()
7070
self.i = 0
7171

72+
def _serialize_random_state(self) -> dict | None:
73+
"""Convert random state to JSON serializable format."""
74+
if self.random_state is not None:
75+
state = self.random_state.get_state()
76+
return {
77+
"bit_generator": state[0],
78+
"state": state[1].tolist(), # Convert numpy array to list
79+
"pos": state[2],
80+
"has_gauss": state[3],
81+
"cached_gaussian": state[4],
82+
}
83+
return None
84+
85+
def _deserialize_random_state(self, state_dict: dict | None) -> None:
86+
"""Restore random state from JSON serializable format."""
87+
if state_dict is not None:
88+
if self.random_state is None:
89+
self.random_state = RandomState()
90+
state = (
91+
state_dict["bit_generator"],
92+
np.array(state_dict["state"], dtype=np.uint32),
93+
state_dict["pos"],
94+
state_dict["has_gauss"],
95+
state_dict["cached_gaussian"],
96+
)
97+
self.random_state.set_state(state)
98+
7299
@abc.abstractmethod
73100
def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
74101
"""Provide access to the base acquisition function."""
@@ -82,6 +109,34 @@ def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> No
82109
if target_space.constraint is not None:
83110
target_space.constraint.fit(target_space.params, target_space._constraint_values)
84111

112+
def get_acquisition_params(self) -> dict[str, Any]:
113+
"""
114+
Get the parameters of the acquisition function.
115+
116+
Returns
117+
-------
118+
dict
119+
The parameters of the acquisition function.
120+
"""
121+
error_msg = (
122+
"Custom AcquisitionFunction subclasses must implement their own get_acquisition_params method."
123+
)
124+
raise NotImplementedError(error_msg)
125+
126+
def set_acquisition_params(self, **params) -> None:
127+
"""
128+
Set the parameters of the acquisition function.
129+
130+
Parameters
131+
----------
132+
**params : dict
133+
The parameters of the acquisition function.
134+
"""
135+
error_msg = (
136+
"Custom AcquisitionFunction subclasses must implement their own set_acquisition_params method."
137+
)
138+
raise NotImplementedError(error_msg)
139+
85140
def suggest(
86141
self,
87142
gp: GaussianProcessRegressor,
@@ -462,6 +517,34 @@ def decay_exploration(self) -> None:
462517
):
463518
self.kappa = self.kappa * self.exploration_decay
464519

520+
def get_acquisition_params(self) -> dict:
521+
"""Get the current acquisition function parameters.
522+
523+
Returns
524+
-------
525+
dict
526+
Dictionary containing the current acquisition function parameters.
527+
"""
528+
return {
529+
"kappa": self.kappa,
530+
"exploration_decay": self.exploration_decay,
531+
"exploration_decay_delay": self.exploration_decay_delay,
532+
"random_state": self._serialize_random_state(),
533+
}
534+
535+
def set_acquisition_params(self, params: dict) -> None:
536+
"""Set the acquisition function parameters.
537+
538+
Parameters
539+
----------
540+
params : dict
541+
Dictionary containing the acquisition function parameters.
542+
"""
543+
self.kappa = params["kappa"]
544+
self.exploration_decay = params["exploration_decay"]
545+
self.exploration_decay_delay = params["exploration_decay_delay"]
546+
self._deserialize_random_state(params["random_state"])
547+
465548

466549
class ProbabilityOfImprovement(AcquisitionFunction):
467550
r"""Probability of Improvement acqusition function.
@@ -595,6 +678,34 @@ def decay_exploration(self) -> None:
595678
):
596679
self.xi = self.xi * self.exploration_decay
597680

681+
def get_acquisition_params(self) -> dict:
682+
"""Get the current acquisition function parameters.
683+
684+
Returns
685+
-------
686+
dict
687+
Dictionary containing the current acquisition function parameters.
688+
"""
689+
return {
690+
"xi": self.xi,
691+
"exploration_decay": self.exploration_decay,
692+
"exploration_decay_delay": self.exploration_decay_delay,
693+
"random_state": self._serialize_random_state(),
694+
}
695+
696+
def set_acquisition_params(self, params: dict) -> None:
697+
"""Set the acquisition function parameters.
698+
699+
Parameters
700+
----------
701+
params : dict
702+
Dictionary containing the acquisition function parameters.
703+
"""
704+
self.xi = params["xi"]
705+
self.exploration_decay = params["exploration_decay"]
706+
self.exploration_decay_delay = params["exploration_decay_delay"]
707+
self._deserialize_random_state(params["random_state"])
708+
598709

599710
class ExpectedImprovement(AcquisitionFunction):
600711
r"""Expected Improvement acqusition function.
@@ -736,6 +847,34 @@ def decay_exploration(self) -> None:
736847
):
737848
self.xi = self.xi * self.exploration_decay
738849

850+
def get_acquisition_params(self) -> dict:
851+
"""Get the current acquisition function parameters.
852+
853+
Returns
854+
-------
855+
dict
856+
Dictionary containing the current acquisition function parameters.
857+
"""
858+
return {
859+
"xi": self.xi,
860+
"exploration_decay": self.exploration_decay,
861+
"exploration_decay_delay": self.exploration_decay_delay,
862+
"random_state": self._serialize_random_state(),
863+
}
864+
865+
def set_acquisition_params(self, params: dict) -> None:
866+
"""Set the acquisition function parameters.
867+
868+
Parameters
869+
----------
870+
params : dict
871+
Dictionary containing the acquisition function parameters.
872+
"""
873+
self.xi = params["xi"]
874+
self.exploration_decay = params["exploration_decay"]
875+
self.exploration_decay_delay = params["exploration_decay_delay"]
876+
self._deserialize_random_state(params["random_state"])
877+
739878

740879
class ConstantLiar(AcquisitionFunction):
741880
"""Constant Liar acquisition function.
@@ -926,6 +1065,38 @@ def suggest(
9261065

9271066
return x_max
9281067

1068+
def get_acquisition_params(self) -> dict:
1069+
"""Get the current acquisition function parameters.
1070+
1071+
Returns
1072+
-------
1073+
dict
1074+
Dictionary containing the current acquisition function parameters.
1075+
"""
1076+
return {
1077+
"dummies": [dummy.tolist() for dummy in self.dummies],
1078+
"base_acquisition_params": self.base_acquisition.get_acquisition_params(),
1079+
"strategy": self.strategy,
1080+
"atol": self.atol,
1081+
"rtol": self.rtol,
1082+
"random_state": self._serialize_random_state(),
1083+
}
1084+
1085+
def set_acquisition_params(self, params: dict) -> None:
1086+
"""Set the acquisition function parameters.
1087+
1088+
Parameters
1089+
----------
1090+
params : dict
1091+
Dictionary containing the acquisition function parameters.
1092+
"""
1093+
self.dummies = [np.array(dummy) for dummy in params["dummies"]]
1094+
self.base_acquisition.set_acquisition_params(params["base_acquisition_params"])
1095+
self.strategy = params["strategy"]
1096+
self.atol = params["atol"]
1097+
self.rtol = params["rtol"]
1098+
self._deserialize_random_state(params["random_state"])
1099+
9291100

9301101
class GPHedge(AcquisitionFunction):
9311102
"""GPHedge acquisition function.
@@ -1043,3 +1214,38 @@ def suggest(
10431214
self.previous_candidates = np.array(x_max)
10441215
idx = self._sample_idx_from_softmax_gains()
10451216
return x_max[idx]
1217+
1218+
def get_acquisition_params(self) -> dict:
1219+
"""Get the current acquisition function parameters.
1220+
1221+
Returns
1222+
-------
1223+
dict
1224+
Dictionary containing the current acquisition function parameters.
1225+
"""
1226+
return {
1227+
"base_acquisitions_params": [acq.get_acquisition_params() for acq in self.base_acquisitions],
1228+
"gains": self.gains.tolist(),
1229+
"previous_candidates": self.previous_candidates.tolist()
1230+
if self.previous_candidates is not None
1231+
else None,
1232+
"gphedge_random_state": self._serialize_random_state(),
1233+
}
1234+
1235+
def set_acquisition_params(self, params: dict) -> None:
1236+
"""Set the acquisition function parameters.
1237+
1238+
Parameters
1239+
----------
1240+
params : dict
1241+
Dictionary containing the acquisition function parameters.
1242+
"""
1243+
for acq, acq_params in zip(self.base_acquisitions, params["base_acquisitions_params"]):
1244+
acq.set_acquisition_params(acq_params)
1245+
1246+
self.gains = np.array(params["gains"])
1247+
self.previous_candidates = (
1248+
np.array(params["previous_candidates"]) if params["previous_candidates"] is not None else None
1249+
)
1250+
1251+
self._deserialize_random_state(params["gphedge_random_state"])

0 commit comments

Comments
 (0)