Skip to content

Commit ab765b4

Browse files
committed
linting, whitespace removal, import structuring
1 parent 7d6b9d6 commit ab765b4

File tree

5 files changed

+302
-353
lines changed

5 files changed

+302
-353
lines changed

bayes_opt/acquisition.py

Lines changed: 110 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ def _serialize_random_state(self) -> dict | None:
7474
if self.random_state is not None:
7575
state = self.random_state.get_state()
7676
return {
77-
'bit_generator': state[0],
78-
'state': state[1].tolist(), # Convert numpy array to list
79-
'pos': state[2],
80-
'has_gauss': state[3],
81-
'cached_gaussian': state[4]
77+
"bit_generator": state[0],
78+
"state": state[1].tolist(), # Convert numpy array to list
79+
"pos": state[2],
80+
"has_gauss": state[3],
81+
"cached_gaussian": state[4],
8282
}
8383
return None
8484

@@ -88,11 +88,11 @@ def _deserialize_random_state(self, state_dict: dict | None) -> None:
8888
if self.random_state is None:
8989
self.random_state = RandomState()
9090
state = (
91-
state_dict['bit_generator'],
92-
np.array(state_dict['state'], dtype=np.uint32),
93-
state_dict['pos'],
94-
state_dict['has_gauss'],
95-
state_dict['cached_gaussian']
91+
state_dict["bit_generator"],
92+
np.array(state_dict["state"], dtype=np.uint32),
93+
state_dict["pos"],
94+
state_dict["has_gauss"],
95+
state_dict["cached_gaussian"],
9696
)
9797
self.random_state.set_state(state)
9898

@@ -102,24 +102,24 @@ def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
102102

103103
def get_acquisition_params(self) -> dict[str, Any]:
104104
"""Get the acquisition function parameters.
105-
105+
106106
Returns
107107
-------
108108
dict
109109
Dictionary containing the acquisition function parameters.
110110
All values must be JSON serializable.
111111
"""
112112
return {}
113-
114-
def set_acquisition_params(self, params: dict[str, Any]) -> None:
113+
114+
def set_acquisition_params(self, params: dict) -> None:
115115
"""Set the acquisition function parameters.
116-
116+
117117
Parameters
118118
----------
119119
params : dict
120120
Dictionary containing the acquisition function parameters.
121121
"""
122-
pass
122+
return {}
123123

124124
def suggest(
125125
self,
@@ -167,7 +167,7 @@ def suggest(
167167

168168
acq = self._get_acq(gp=gp, constraint=target_space.constraint)
169169
return self._acq_min(acq, target_space, n_random=n_random, n_l_bfgs_b=n_l_bfgs_b)
170-
170+
171171
def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> None:
172172
# Sklearn's GP throws a large number of warnings at times, but
173173
# we don't really need to see them here.
@@ -501,16 +501,30 @@ def decay_exploration(self) -> None:
501501
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
502502
):
503503
self.kappa = self.kappa * self.exploration_decay
504-
504+
505505
def get_acquisition_params(self) -> dict:
506+
"""Get the current acquisition function parameters.
507+
508+
Returns
509+
-------
510+
dict
511+
Dictionary containing the current acquisition function parameters.
512+
"""
506513
return {
507514
"kappa": self.kappa,
508515
"exploration_decay": self.exploration_decay,
509516
"exploration_decay_delay": self.exploration_decay_delay,
510-
"random_state": self._serialize_random_state()
517+
"random_state": self._serialize_random_state(),
511518
}
512519

513520
def set_acquisition_params(self, params: dict) -> None:
521+
"""Set the acquisition function parameters.
522+
523+
Parameters
524+
----------
525+
params : dict
526+
Dictionary containing the acquisition function parameters.
527+
"""
514528
self.kappa = params["kappa"]
515529
self.exploration_decay = params["exploration_decay"]
516530
self.exploration_decay_delay = params["exploration_decay_delay"]
@@ -648,17 +662,30 @@ def decay_exploration(self) -> None:
648662
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
649663
):
650664
self.xi = self.xi * self.exploration_decay
651-
665+
652666
def get_acquisition_params(self) -> dict:
653-
"""Get the acquisition function parameters."""
667+
"""Get the current acquisition function parameters.
668+
669+
Returns
670+
-------
671+
dict
672+
Dictionary containing the current acquisition function parameters.
673+
"""
654674
return {
655675
"xi": self.xi,
656676
"exploration_decay": self.exploration_decay,
657677
"exploration_decay_delay": self.exploration_decay_delay,
658-
"random_state": self._serialize_random_state()
678+
"random_state": self._serialize_random_state(),
659679
}
660-
680+
661681
def set_acquisition_params(self, params: dict) -> None:
682+
"""Set the acquisition function parameters.
683+
684+
Parameters
685+
----------
686+
params : dict
687+
Dictionary containing the acquisition function parameters.
688+
"""
662689
self.xi = params["xi"]
663690
self.exploration_decay = params["exploration_decay"]
664691
self.exploration_decay_delay = params["exploration_decay_delay"]
@@ -804,16 +831,30 @@ def decay_exploration(self) -> None:
804831
self.exploration_decay_delay is None or self.exploration_decay_delay <= self.i
805832
):
806833
self.xi = self.xi * self.exploration_decay
807-
834+
808835
def get_acquisition_params(self) -> dict:
836+
"""Get the current acquisition function parameters.
837+
838+
Returns
839+
-------
840+
dict
841+
Dictionary containing the current acquisition function parameters.
842+
"""
809843
return {
810844
"xi": self.xi,
811845
"exploration_decay": self.exploration_decay,
812846
"exploration_decay_delay": self.exploration_decay_delay,
813-
"random_state": self._serialize_random_state()
847+
"random_state": self._serialize_random_state(),
814848
}
815-
849+
816850
def set_acquisition_params(self, params: dict) -> None:
851+
"""Set the acquisition function parameters.
852+
853+
Parameters
854+
----------
855+
params : dict
856+
Dictionary containing the acquisition function parameters.
857+
"""
817858
self.xi = params["xi"]
818859
self.exploration_decay = params["exploration_decay"]
819860
self.exploration_decay_delay = params["exploration_decay_delay"]
@@ -1008,18 +1049,32 @@ def suggest(
10081049
self.dummies.append(x_max)
10091050

10101051
return x_max
1011-
1052+
10121053
def get_acquisition_params(self) -> dict:
1054+
"""Get the current acquisition function parameters.
1055+
1056+
Returns
1057+
-------
1058+
dict
1059+
Dictionary containing the current acquisition function parameters.
1060+
"""
10131061
return {
10141062
"dummies": [dummy.tolist() for dummy in self.dummies],
10151063
"base_acquisition_params": self.base_acquisition.get_acquisition_params(),
10161064
"strategy": self.strategy,
10171065
"atol": self.atol,
10181066
"rtol": self.rtol,
1019-
"random_state": self._serialize_random_state()
1067+
"random_state": self._serialize_random_state(),
10201068
}
1021-
1069+
10221070
def set_acquisition_params(self, params: dict) -> None:
1071+
"""Set the acquisition function parameters.
1072+
1073+
Parameters
1074+
----------
1075+
params : dict
1076+
Dictionary containing the acquisition function parameters.
1077+
"""
10231078
self.dummies = [np.array(dummy) for dummy in params["dummies"]]
10241079
self.base_acquisition.set_acquisition_params(params["base_acquisition_params"])
10251080
self.strategy = params["strategy"]
@@ -1144,28 +1199,42 @@ def suggest(
11441199
self.previous_candidates = np.array(x_max)
11451200
idx = self._sample_idx_from_softmax_gains()
11461201
return x_max[idx]
1147-
1202+
11481203
def get_acquisition_params(self) -> dict:
1204+
"""Get the current acquisition function parameters.
1205+
1206+
Returns
1207+
-------
1208+
dict
1209+
Dictionary containing the current acquisition function parameters.
1210+
"""
11491211
return {
11501212
"base_acquisitions_params": [acq.get_acquisition_params() for acq in self.base_acquisitions],
11511213
"gains": self.gains.tolist(),
1152-
"previous_candidates": self.previous_candidates.tolist() if self.previous_candidates is not None else None,
1153-
"random_states": [acq._serialize_random_state() for acq in self.base_acquisitions] + [self._serialize_random_state()]
1214+
"previous_candidates": self.previous_candidates.tolist()
1215+
if self.previous_candidates is not None
1216+
else None,
1217+
"random_states": [acq._serialize_random_state() for acq in self.base_acquisitions]
1218+
+ [self._serialize_random_state()],
11541219
}
1155-
1220+
11561221
def set_acquisition_params(self, params: dict) -> None:
1222+
"""Set the acquisition function parameters.
1223+
1224+
Parameters
1225+
----------
1226+
params : dict
1227+
Dictionary containing the acquisition function parameters.
1228+
"""
11571229
for acq, acq_params, random_state in zip(
1158-
self.base_acquisitions,
1159-
params["base_acquisitions_params"],
1160-
params["random_states"][:-1]
1230+
self.base_acquisitions, params["base_acquisitions_params"], params["random_states"][:-1]
11611231
):
11621232
acq.set_acquisition_params(acq_params)
11631233
acq._deserialize_random_state(random_state)
1164-
1234+
11651235
self.gains = np.array(params["gains"])
1166-
self.previous_candidates = (np.array(params["previous_candidates"])
1167-
if params["previous_candidates"] is not None
1168-
else None)
1169-
1170-
self._deserialize_random_state(params["random_states"][-1])
1236+
self.previous_candidates = (
1237+
np.array(params["previous_candidates"]) if params["previous_candidates"] is not None else None
1238+
)
11711239

1240+
self._deserialize_random_state(params["random_states"][-1])

0 commit comments

Comments
 (0)