Skip to content

Commit 2b514aa

Browse files
committed
add the random state to the set of things to list of saved items
1 parent 7be3854 commit 2b514aa

File tree

1 file changed

+143
-9
lines changed

1 file changed

+143
-9
lines changed

bayes_opt/acquisition.py

Lines changed: 143 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,57 @@ def __init__(self, random_state: int | RandomState | None = None) -> None:
6969
self.random_state = RandomState()
7070
self.i = 0
7171

72+
def _serialize_random_state(self) -> dict | None:
73+
"""Convert random state to JSON serializable format."""
74+
if self.random_state is not None:
75+
state = self.random_state.get_state()
76+
return {
77+
'bit_generator': state[0],
78+
'state': state[1].tolist(), # Convert numpy array to list
79+
'pos': state[2],
80+
'has_gauss': state[3],
81+
'cached_gaussian': state[4]
82+
}
83+
return None
84+
85+
def _deserialize_random_state(self, state_dict: dict | None) -> None:
86+
"""Restore random state from JSON serializable format."""
87+
if state_dict is not None:
88+
if self.random_state is None:
89+
self.random_state = RandomState()
90+
state = (
91+
state_dict['bit_generator'],
92+
np.array(state_dict['state'], dtype=np.uint32),
93+
state_dict['pos'],
94+
state_dict['has_gauss'],
95+
state_dict['cached_gaussian']
96+
)
97+
self.random_state.set_state(state)
98+
7299
@abc.abstractmethod
73100
def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
74101
"""Provide access to the base acquisition function."""
75-
76-
def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> None:
77-
# Sklearn's GP throws a large number of warnings at times, but
78-
# we don't really need to see them here.
79-
with warnings.catch_warnings():
80-
warnings.simplefilter("ignore")
81-
gp.fit(target_space.params, target_space.target)
82-
if target_space.constraint is not None:
83-
target_space.constraint.fit(target_space.params, target_space._constraint_values)
102+
103+
@abc.abstractmethod
104+
def get_acquisition_params(self) -> dict[str, Any]:
105+
"""Get the acquisition function parameters.
106+
107+
Returns
108+
-------
109+
dict
110+
Dictionary containing the acquisition function parameters.
111+
All values must be JSON serializable.
112+
"""
113+
114+
@abc.abstractmethod
115+
def set_acquisition_params(self, params: dict[str, Any]) -> None:
116+
"""Set the acquisition function parameters.
117+
118+
Parameters
119+
----------
120+
params : dict
121+
Dictionary containing the acquisition function parameters.
122+
"""
84123

85124
def suggest(
86125
self,
@@ -128,6 +167,15 @@ def suggest(
128167

129168
acq = self._get_acq(gp=gp, constraint=target_space.constraint)
130169
return self._acq_min(acq, target_space, n_random=n_random, n_l_bfgs_b=n_l_bfgs_b)
170+
171+
def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> None:
172+
# Sklearn's GP throws a large number of warnings at times, but
173+
# we don't really need to see them here.
174+
with warnings.catch_warnings():
175+
warnings.simplefilter("ignore")
176+
gp.fit(target_space.params, target_space.target)
177+
if target_space.constraint is not None:
178+
target_space.constraint.fit(target_space.params, target_space._constraint_values)
131179

132180
def _get_acq(
133181
self, gp: GaussianProcessRegressor, constraint: ConstraintModel | None = None
@@ -453,6 +501,20 @@ def decay_exploration(self) -> None:
453501
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
454502
):
455503
self.kappa = self.kappa * self.exploration_decay
504+
505+
def get_acquisition_params(self) -> dict:
506+
return {
507+
"kappa": self.kappa,
508+
"exploration_decay": self.exploration_decay,
509+
"exploration_decay_delay": self.exploration_decay_delay,
510+
"random_state": self._serialize_random_state()
511+
}
512+
513+
def set_acquisition_params(self, params: dict) -> None:
514+
self.kappa = params["kappa"]
515+
self.exploration_decay = params["exploration_decay"]
516+
self.exploration_decay_delay = params["exploration_decay_delay"]
517+
self._deserialize_random_state(params["random_state"])
456518

457519

458520
class ProbabilityOfImprovement(AcquisitionFunction):
@@ -586,6 +648,21 @@ def decay_exploration(self) -> None:
586648
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
587649
):
588650
self.xi = self.xi * self.exploration_decay
651+
652+
def get_acquisition_params(self) -> dict:
653+
"""Get the acquisition function parameters."""
654+
return {
655+
"xi": self.xi,
656+
"exploration_decay": self.exploration_decay,
657+
"exploration_decay_delay": self.exploration_decay_delay,
658+
"random_state": self._serialize_random_state()
659+
}
660+
661+
def set_acquisition_params(self, params: dict) -> None:
662+
self.xi = params["xi"]
663+
self.exploration_decay = params["exploration_decay"]
664+
self.exploration_decay_delay = params["exploration_decay_delay"]
665+
self._deserialize_random_state(params["random_state"])
589666

590667

591668
class ExpectedImprovement(AcquisitionFunction):
@@ -727,6 +804,20 @@ def decay_exploration(self) -> None:
727804
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
728805
):
729806
self.xi = self.xi * self.exploration_decay
807+
808+
def get_acquisition_params(self) -> dict:
809+
return {
810+
"xi": self.xi,
811+
"exploration_decay": self.exploration_decay,
812+
"exploration_decay_delay": self.exploration_decay_delay,
813+
"random_state": self._serialize_random_state()
814+
}
815+
816+
def set_acquisition_params(self, params: dict) -> None:
817+
self.xi = params["xi"]
818+
self.exploration_decay = params["exploration_decay"]
819+
self.exploration_decay_delay = params["exploration_decay_delay"]
820+
self._deserialize_random_state(params["random_state"])
730821

731822

732823
class ConstantLiar(AcquisitionFunction):
@@ -917,6 +1008,24 @@ def suggest(
9171008
self.dummies.append(x_max)
9181009

9191010
return x_max
1011+
1012+
def get_acquisition_params(self) -> dict:
1013+
return {
1014+
"dummies": [dummy.tolist() for dummy in self.dummies],
1015+
"base_acquisition_params": self.base_acquisition.get_acquisition_params(),
1016+
"strategy": self.strategy,
1017+
"atol": self.atol,
1018+
"rtol": self.rtol,
1019+
"random_state": self._serialize_random_state()
1020+
}
1021+
1022+
def set_acquisition_params(self, params: dict) -> None:
1023+
self.dummies = [np.array(dummy) for dummy in params["dummies"]]
1024+
self.base_acquisition.set_acquisition_params(params["base_acquisition_params"])
1025+
self.strategy = params["strategy"]
1026+
self.atol = params["atol"]
1027+
self.rtol = params["rtol"]
1028+
self._deserialize_random_state(params["random_state"])
9201029

9211030

9221031
class GPHedge(AcquisitionFunction):
@@ -1035,3 +1144,28 @@ def suggest(
10351144
self.previous_candidates = np.array(x_max)
10361145
idx = self._sample_idx_from_softmax_gains()
10371146
return x_max[idx]
1147+
1148+
def get_acquisition_params(self) -> dict:
1149+
return {
1150+
"base_acquisitions_params": [acq.get_acquisition_params() for acq in self.base_acquisitions],
1151+
"gains": self.gains.tolist(),
1152+
"previous_candidates": self.previous_candidates.tolist() if self.previous_candidates is not None else None,
1153+
"random_states": [acq._serialize_random_state() for acq in self.base_acquisitions] + [self._serialize_random_state()]
1154+
}
1155+
1156+
def set_acquisition_params(self, params: dict) -> None:
1157+
for acq, acq_params, random_state in zip(
1158+
self.base_acquisitions,
1159+
params["base_acquisitions_params"],
1160+
params["random_states"][:-1]
1161+
):
1162+
acq.set_acquisition_params(acq_params)
1163+
acq._deserialize_random_state(random_state)
1164+
1165+
self.gains = np.array(params["gains"])
1166+
self.previous_candidates = (np.array(params["previous_candidates"])
1167+
if params["previous_candidates"] is not None
1168+
else None)
1169+
1170+
self._deserialize_random_state(params["random_states"][-1])
1171+

0 commit comments

Comments
 (0)