Skip to content

Commit be73262

Browse files
committed
Move state loading to a separate function; add functionality for saving acquisition-function state
1 parent 2b514aa commit be73262

File tree

1 file changed

+46
-55
lines changed

1 file changed

+46
-55
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 46 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,6 @@ class BayesianOptimization(Observable):
107107
This behavior may be desired in high noise situations where repeatedly probing
108108
the same point will give different answers. In other situations, the acquisition
109109
may occasionally generate a duplicate point.
110-
111-
load_state_path: str | Path | None, optional (default=None)
112-
If provided, load optimizer state from this path instead of initializing fresh
113110
"""
114111

115112
def __init__(
@@ -122,7 +119,6 @@ def __init__(
122119
verbose: int = 2,
123120
bounds_transformer: DomainTransformer | None = None,
124121
allow_duplicate_points: bool = False,
125-
load_state_path: str | Path | None = None,
126122
):
127123
self._random_state = ensure_rng(random_state)
128124
self._allow_duplicate_points = allow_duplicate_points
@@ -181,55 +177,6 @@ def __init__(
181177
self._sorting_warning_already_shown = False # TODO: remove in future version
182178
super().__init__(events=DEFAULT_EVENTS)
183179

184-
if load_state_path is not None:
185-
with Path(load_state_path).open('r') as file:
186-
state = json.load(file)
187-
self._set_state_from_dict(state)
188-
189-
def _set_state_from_dict(self, state: dict[str, Any]) -> None:
190-
"""Set optimizer state from a dictionary of saved values."""
191-
if state["random_state"] is not None:
192-
random_state_tuple = (
193-
state["random_state"]["bit_generator"],
194-
np.array(state["random_state"]["state"], dtype=np.uint32),
195-
state["random_state"]["pos"],
196-
state["random_state"]["has_gauss"],
197-
state["random_state"]["cached_gaussian"],
198-
)
199-
self._random_state.set_state(random_state_tuple)
200-
201-
self._gp.set_params(**state["gp_params"])
202-
203-
# Handle kernel separately since it needs reconstruction
204-
if isinstance(self._gp.kernel, dict):
205-
kernel_params = self._gp.kernel
206-
self._gp.kernel = Matern(
207-
length_scale=kernel_params['length_scale'],
208-
length_scale_bounds=kernel_params['length_scale_bounds'],
209-
nu=kernel_params['nu']
210-
)
211-
212-
# Register previous points
213-
params_array = np.array(state["params"])
214-
target_array = np.array(state["target"])
215-
constraint_array = (np.array(state["constraint_values"])
216-
if state["constraint_values"] is not None
217-
else None)
218-
219-
for i in range(len(params_array)):
220-
params = self._space.array_to_params(params_array[i])
221-
target = target_array[i]
222-
constraint = constraint_array[i] if constraint_array is not None else None
223-
self.register(
224-
params=params,
225-
target=target,
226-
constraint_value=constraint,
227-
)
228-
229-
# Fit GP if there are samples
230-
if len(self._space) > 0:
231-
self._gp.fit(self._space.params, self._space.target)
232-
233180
@property
234181
def space(self) -> TargetSpace:
235182
"""Return the target space associated with the optimizer."""
@@ -433,12 +380,12 @@ def save_state(self, path: str | PathLike[str]) -> None:
433380
'has_gauss': state_tuple[3],
434381
'cached_gaussian': state_tuple[4],
435382
}
436-
383+
437384
# Get constraint values if they exist
438385
constraint_values = (self._space._constraint_values.tolist()
439386
if self.is_constrained
440387
else None)
441-
388+
acquisition_params = self._acquisition_function.get_acquisition_params()
442389
state = {
443390
"pbounds": {
444391
key: self._space._bounds[i].tolist()
@@ -457,7 +404,51 @@ def save_state(self, path: str | PathLike[str]) -> None:
457404
"allow_duplicate_points": self._allow_duplicate_points,
458405
"verbose": self._verbose,
459406
"random_state": random_state,
407+
"acquisition_params": acquisition_params,
460408
}
461409

462410
with Path(path).open('w') as f:
463411
json.dump(state, f, indent=2)
412+
413+
def load_state(self, path: str | PathLike[str]) -> None:
414+
with Path(path).open('r') as file:
415+
state = json.load(file)
416+
417+
if state["random_state"] is not None:
418+
random_state_tuple = (
419+
state["random_state"]["bit_generator"],
420+
np.array(state["random_state"]["state"], dtype=np.uint32),
421+
state["random_state"]["pos"],
422+
state["random_state"]["has_gauss"],
423+
state["random_state"]["cached_gaussian"],
424+
)
425+
self._random_state.set_state(random_state_tuple)
426+
427+
self._gp.set_params(**state["gp_params"])
428+
429+
if isinstance(self._gp.kernel, dict):
430+
kernel_params = self._gp.kernel
431+
self._gp.kernel = Matern(
432+
length_scale=kernel_params['length_scale'],
433+
length_scale_bounds=tuple(kernel_params['length_scale_bounds']),
434+
nu=kernel_params['nu']
435+
)
436+
437+
params_array = np.array(state["params"])
438+
target_array = np.array(state["target"])
439+
constraint_array = (np.array(state["constraint_values"])
440+
if state["constraint_values"] is not None
441+
else None)
442+
443+
for i in range(len(params_array)):
444+
params = self._space.array_to_params(params_array[i])
445+
target = target_array[i]
446+
constraint = constraint_array[i] if constraint_array is not None else None
447+
self.register(
448+
params=params,
449+
target=target,
450+
constraint_value=constraint
451+
)
452+
453+
self._acquisition_function.set_acquisition_params(state["acquisition_params"])
454+
self._gp.fit(self._space.params, self._space.target)

0 commit comments

Comments
 (0)