Skip to content

Commit 3d3e538

Browse files
committed
remove duplicate acquisition functions random state
1 parent a68d727 commit 3d3e538

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

bayes_opt/acquisition.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,8 +1221,7 @@ def get_acquisition_params(self) -> dict:
12211221
"previous_candidates": self.previous_candidates.tolist()
12221222
if self.previous_candidates is not None
12231223
else None,
1224-
"random_states": [acq._serialize_random_state() for acq in self.base_acquisitions]
1225-
+ [self._serialize_random_state()],
1224+
"gphedge_random_state": self._serialize_random_state(),
12261225
}
12271226

12281227
def set_acquisition_params(self, params: dict) -> None:
@@ -1233,15 +1232,14 @@ def set_acquisition_params(self, params: dict) -> None:
12331232
params : dict
12341233
Dictionary containing the acquisition function parameters.
12351234
"""
1236-
for acq, acq_params, random_state in zip(
1237-
self.base_acquisitions, params["base_acquisitions_params"], params["random_states"][:-1]
1235+
for acq, acq_params in zip(
1236+
self.base_acquisitions, params["base_acquisitions_params"]
12381237
):
12391238
acq.set_acquisition_params(acq_params)
1240-
acq._deserialize_random_state(random_state)
12411239

12421240
self.gains = np.array(params["gains"])
12431241
self.previous_candidates = (
12441242
np.array(params["previous_candidates"]) if params["previous_candidates"] is not None else None
12451243
)
12461244

1247-
self._deserialize_random_state(params["random_states"][-1])
1245+
self._deserialize_random_state(params["gphedge_random_state"])

tests/test_acquisition.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,3 +596,45 @@ def test_integration_constrained(target_func_x_and_y, pbounds, constraint, tmp_p
596596
new_optimizer.load_state(state_path)
597597

598598
verify_optimizers_match(optimizer, new_optimizer)
599+
600+
601+
def test_custom_acquisition_without_get_params():
602+
"""Test that a custom acquisition function without get_acquisition_params raises NotImplementedError."""
603+
604+
class CustomAcqWithoutGetParams(acquisition.AcquisitionFunction):
605+
def __init__(self, random_state=None):
606+
super().__init__(random_state=random_state)
607+
608+
def base_acq(self, mean, std):
609+
return mean + std
610+
611+
def set_acquisition_params(self, params):
612+
pass
613+
614+
acq = CustomAcqWithoutGetParams()
615+
with pytest.raises(
616+
NotImplementedError,
617+
match="Custom AcquisitionFunction subclasses must implement their own get_acquisition_params method",
618+
):
619+
acq.get_acquisition_params()
620+
621+
622+
def test_custom_acquisition_without_set_params():
623+
"""Test that a custom acquisition function without set_acquisition_params raises NotImplementedError."""
624+
625+
class CustomAcqWithoutSetParams(acquisition.AcquisitionFunction):
626+
def __init__(self, random_state=None):
627+
super().__init__(random_state=random_state)
628+
629+
def base_acq(self, mean, std):
630+
return mean + std
631+
632+
def get_acquisition_params(self):
633+
return {}
634+
635+
acq = CustomAcqWithoutSetParams()
636+
with pytest.raises(
637+
NotImplementedError,
638+
match="Custom AcquisitionFunction subclasses must implement their own set_acquisition_params method",
639+
):
640+
acq.set_acquisition_params(params={})

0 commit comments

Comments
 (0)