Skip to content

Commit 4aab0b8

Browse files
committed
reorganize state saving and loading for consistency
1 parent c26504e commit 4aab0b8

File tree

1 file changed

+52
-24
lines changed

1 file changed

+52
-24
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,18 @@ def save_state(self, path: str | PathLike[str]) -> None:
369369
----------
370370
path : str or PathLike
371371
Path to save the optimization state
372+
373+
Raises
374+
------
375+
ValueError
376+
If attempting to save state before collecting any samples.
372377
"""
378+
if len(self._space) == 0:
379+
raise ValueError(
380+
"Cannot save optimizer state before collecting any samples. "
381+
"Please probe or register at least one point before saving."
382+
)
383+
373384
random_state = None
374385
if self._random_state is not None:
375386
state_tuple = self._random_state.get_state()
@@ -391,8 +402,14 @@ def save_state(self, path: str | PathLike[str]) -> None:
391402
key: self._space._bounds[i].tolist()
392403
for i, key in enumerate(self._space.keys)
393404
},
405+
# Add current transformed bounds if using bounds transformer
406+
"transformed_bounds": (
407+
self._space.bounds.tolist()
408+
if self._bounds_transformer
409+
else None
410+
),
394411
"keys": self._space.keys,
395-
"params": self._space.params.tolist(),
412+
"params": np.array(self._space.params).tolist(),
396413
"target": self._space.target.tolist(),
397414
"constraint_values": constraint_values,
398415
"gp_params": {
@@ -414,28 +431,8 @@ def load_state(self, path: str | PathLike[str]) -> None:
414431
with Path(path).open('r') as file:
415432
state = json.load(file)
416433

417-
if state["random_state"] is not None:
418-
random_state_tuple = (
419-
state["random_state"]["bit_generator"],
420-
np.array(state["random_state"]["state"], dtype=np.uint32),
421-
state["random_state"]["pos"],
422-
state["random_state"]["has_gauss"],
423-
state["random_state"]["cached_gaussian"],
424-
)
425-
self._random_state.set_state(random_state_tuple)
426-
427-
self._gp.set_params(**state["gp_params"])
428-
429-
if isinstance(self._gp.kernel, dict):
430-
kernel_params = self._gp.kernel
431-
self._gp.kernel = Matern(
432-
length_scale=kernel_params['length_scale'],
433-
length_scale_bounds=tuple(kernel_params['length_scale_bounds']),
434-
nu=kernel_params['nu']
435-
)
436-
437-
params_array = np.array(state["params"])
438-
target_array = np.array(state["target"])
434+
params_array = np.asarray(state["params"], dtype=np.float64)
435+
target_array = np.asarray(state["target"], dtype=np.float64)
439436
constraint_array = (np.array(state["constraint_values"])
440437
if state["constraint_values"] is not None
441438
else None)
@@ -449,6 +446,37 @@ def load_state(self, path: str | PathLike[str]) -> None:
449446
target=target,
450447
constraint_value=constraint
451448
)
452-
449+
453450
self._acquisition_function.set_acquisition_params(state["acquisition_params"])
451+
452+
if state.get("transformed_bounds") and self._bounds_transformer:
453+
new_bounds = {
454+
key: bounds for key, bounds in zip(
455+
self._space.keys,
456+
np.array(state["transformed_bounds"])
457+
)
458+
}
459+
self._space.set_bounds(new_bounds)
460+
self._bounds_transformer.initialize(self._space)
461+
462+
self._gp.set_params(**state["gp_params"])
463+
if isinstance(self._gp.kernel, dict):
464+
kernel_params = self._gp.kernel
465+
self._gp.kernel = Matern(
466+
length_scale=kernel_params['length_scale'],
467+
length_scale_bounds=tuple(kernel_params['length_scale_bounds']),
468+
nu=kernel_params['nu']
469+
)
470+
454471
self._gp.fit(self._space.params, self._space.target)
472+
473+
if state["random_state"] is not None:
474+
random_state_tuple = (
475+
state["random_state"]["bit_generator"],
476+
np.array(state["random_state"]["state"], dtype=np.uint32),
477+
state["random_state"]["pos"],
478+
state["random_state"]["has_gauss"],
479+
state["random_state"]["cached_gaussian"],
480+
)
481+
self._random_state.set_state(random_state_tuple)
482+

0 commit comments

Comments
 (0)