@@ -74,11 +74,11 @@ def _serialize_random_state(self) -> dict | None:
7474 if self .random_state is not None :
7575 state = self .random_state .get_state ()
7676 return {
77- ' bit_generator' : state [0 ],
78- ' state' : state [1 ].tolist (), # Convert numpy array to list
79- ' pos' : state [2 ],
80- ' has_gauss' : state [3 ],
81- ' cached_gaussian' : state [4 ]
77+ " bit_generator" : state [0 ],
78+ " state" : state [1 ].tolist (), # Convert numpy array to list
79+ " pos" : state [2 ],
80+ " has_gauss" : state [3 ],
81+ " cached_gaussian" : state [4 ],
8282 }
8383 return None
8484
@@ -88,11 +88,11 @@ def _deserialize_random_state(self, state_dict: dict | None) -> None:
8888 if self .random_state is None :
8989 self .random_state = RandomState ()
9090 state = (
91- state_dict [' bit_generator' ],
92- np .array (state_dict [' state' ], dtype = np .uint32 ),
93- state_dict [' pos' ],
94- state_dict [' has_gauss' ],
95- state_dict [' cached_gaussian' ]
91+ state_dict [" bit_generator" ],
92+ np .array (state_dict [" state" ], dtype = np .uint32 ),
93+ state_dict [" pos" ],
94+ state_dict [" has_gauss" ],
95+ state_dict [" cached_gaussian" ],
9696 )
9797 self .random_state .set_state (state )
9898
@@ -102,24 +102,24 @@ def base_acq(self, *args: Any, **kwargs: Any) -> NDArray[Float]:
102102
103103 def get_acquisition_params (self ) -> dict [str , Any ]:
104104 """Get the acquisition function parameters.
105-
105+
106106 Returns
107107 -------
108108 dict
109109 Dictionary containing the acquisition function parameters.
110110 All values must be JSON serializable.
111111 """
112112 return {}
113-
114- def set_acquisition_params (self , params : dict [ str , Any ] ) -> None :
113+
114+ def set_acquisition_params (self , params : dict ) -> None :
115115 """Set the acquisition function parameters.
116-
116+
117117 Parameters
118118 ----------
119119 params : dict
120120 Dictionary containing the acquisition function parameters.
121121 """
122- pass
122+ return {}
123123
124124 def suggest (
125125 self ,
@@ -167,7 +167,7 @@ def suggest(
167167
168168 acq = self ._get_acq (gp = gp , constraint = target_space .constraint )
169169 return self ._acq_min (acq , target_space , n_random = n_random , n_l_bfgs_b = n_l_bfgs_b )
170-
170+
171171 def _fit_gp (self , gp : GaussianProcessRegressor , target_space : TargetSpace ) -> None :
172172 # Sklearn's GP throws a large number of warnings at times, but
173173 # we don't really need to see them here.
@@ -501,16 +501,30 @@ def decay_exploration(self) -> None:
501501 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
502502 ):
503503 self .kappa = self .kappa * self .exploration_decay
504-
504+
505505 def get_acquisition_params (self ) -> dict :
506+ """Get the current acquisition function parameters.
507+
508+ Returns
509+ -------
510+ dict
511+ Dictionary containing the current acquisition function parameters.
512+ """
506513 return {
507514 "kappa" : self .kappa ,
508515 "exploration_decay" : self .exploration_decay ,
509516 "exploration_decay_delay" : self .exploration_decay_delay ,
510- "random_state" : self ._serialize_random_state ()
517+ "random_state" : self ._serialize_random_state (),
511518 }
512519
513520 def set_acquisition_params (self , params : dict ) -> None :
521+ """Set the acquisition function parameters.
522+
523+ Parameters
524+ ----------
525+ params : dict
526+ Dictionary containing the acquisition function parameters.
527+ """
514528 self .kappa = params ["kappa" ]
515529 self .exploration_decay = params ["exploration_decay" ]
516530 self .exploration_decay_delay = params ["exploration_decay_delay" ]
@@ -648,17 +662,30 @@ def decay_exploration(self) -> None:
648662 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
649663 ):
650664 self .xi = self .xi * self .exploration_decay
651-
665+
652666 def get_acquisition_params (self ) -> dict :
653- """Get the acquisition function parameters."""
667+ """Get the current acquisition function parameters.
668+
669+ Returns
670+ -------
671+ dict
672+ Dictionary containing the current acquisition function parameters.
673+ """
654674 return {
655675 "xi" : self .xi ,
656676 "exploration_decay" : self .exploration_decay ,
657677 "exploration_decay_delay" : self .exploration_decay_delay ,
658- "random_state" : self ._serialize_random_state ()
678+ "random_state" : self ._serialize_random_state (),
659679 }
660-
680+
661681 def set_acquisition_params (self , params : dict ) -> None :
682+ """Set the acquisition function parameters.
683+
684+ Parameters
685+ ----------
686+ params : dict
687+ Dictionary containing the acquisition function parameters.
688+ """
662689 self .xi = params ["xi" ]
663690 self .exploration_decay = params ["exploration_decay" ]
664691 self .exploration_decay_delay = params ["exploration_decay_delay" ]
@@ -804,16 +831,30 @@ def decay_exploration(self) -> None:
804831 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
805832 ):
806833 self .xi = self .xi * self .exploration_decay
807-
834+
808835 def get_acquisition_params (self ) -> dict :
836+ """Get the current acquisition function parameters.
837+
838+ Returns
839+ -------
840+ dict
841+ Dictionary containing the current acquisition function parameters.
842+ """
809843 return {
810844 "xi" : self .xi ,
811845 "exploration_decay" : self .exploration_decay ,
812846 "exploration_decay_delay" : self .exploration_decay_delay ,
813- "random_state" : self ._serialize_random_state ()
847+ "random_state" : self ._serialize_random_state (),
814848 }
815-
849+
816850 def set_acquisition_params (self , params : dict ) -> None :
851+ """Set the acquisition function parameters.
852+
853+ Parameters
854+ ----------
855+ params : dict
856+ Dictionary containing the acquisition function parameters.
857+ """
817858 self .xi = params ["xi" ]
818859 self .exploration_decay = params ["exploration_decay" ]
819860 self .exploration_decay_delay = params ["exploration_decay_delay" ]
@@ -1008,18 +1049,32 @@ def suggest(
10081049 self .dummies .append (x_max )
10091050
10101051 return x_max
1011-
1052+
10121053 def get_acquisition_params (self ) -> dict :
1054+ """Get the current acquisition function parameters.
1055+
1056+ Returns
1057+ -------
1058+ dict
1059+ Dictionary containing the current acquisition function parameters.
1060+ """
10131061 return {
10141062 "dummies" : [dummy .tolist () for dummy in self .dummies ],
10151063 "base_acquisition_params" : self .base_acquisition .get_acquisition_params (),
10161064 "strategy" : self .strategy ,
10171065 "atol" : self .atol ,
10181066 "rtol" : self .rtol ,
1019- "random_state" : self ._serialize_random_state ()
1067+ "random_state" : self ._serialize_random_state (),
10201068 }
1021-
1069+
10221070 def set_acquisition_params (self , params : dict ) -> None :
1071+ """Set the acquisition function parameters.
1072+
1073+ Parameters
1074+ ----------
1075+ params : dict
1076+ Dictionary containing the acquisition function parameters.
1077+ """
10231078 self .dummies = [np .array (dummy ) for dummy in params ["dummies" ]]
10241079 self .base_acquisition .set_acquisition_params (params ["base_acquisition_params" ])
10251080 self .strategy = params ["strategy" ]
@@ -1144,28 +1199,42 @@ def suggest(
11441199 self .previous_candidates = np .array (x_max )
11451200 idx = self ._sample_idx_from_softmax_gains ()
11461201 return x_max [idx ]
1147-
1202+
11481203 def get_acquisition_params (self ) -> dict :
1204+ """Get the current acquisition function parameters.
1205+
1206+ Returns
1207+ -------
1208+ dict
1209+ Dictionary containing the current acquisition function parameters.
1210+ """
11491211 return {
11501212 "base_acquisitions_params" : [acq .get_acquisition_params () for acq in self .base_acquisitions ],
11511213 "gains" : self .gains .tolist (),
1152- "previous_candidates" : self .previous_candidates .tolist () if self .previous_candidates is not None else None ,
1153- "random_states" : [acq ._serialize_random_state () for acq in self .base_acquisitions ] + [self ._serialize_random_state ()]
1214+ "previous_candidates" : self .previous_candidates .tolist ()
1215+ if self .previous_candidates is not None
1216+ else None ,
1217+ "random_states" : [acq ._serialize_random_state () for acq in self .base_acquisitions ]
1218+ + [self ._serialize_random_state ()],
11541219 }
1155-
1220+
11561221 def set_acquisition_params (self , params : dict ) -> None :
1222+ """Set the acquisition function parameters.
1223+
1224+ Parameters
1225+ ----------
1226+ params : dict
1227+ Dictionary containing the acquisition function parameters.
1228+ """
11571229 for acq , acq_params , random_state in zip (
1158- self .base_acquisitions ,
1159- params ["base_acquisitions_params" ],
1160- params ["random_states" ][:- 1 ]
1230+ self .base_acquisitions , params ["base_acquisitions_params" ], params ["random_states" ][:- 1 ]
11611231 ):
11621232 acq .set_acquisition_params (acq_params )
11631233 acq ._deserialize_random_state (random_state )
1164-
1234+
11651235 self .gains = np .array (params ["gains" ])
1166- self .previous_candidates = (np .array (params ["previous_candidates" ])
1167- if params ["previous_candidates" ] is not None
1168- else None )
1169-
1170- self ._deserialize_random_state (params ["random_states" ][- 1 ])
1236+ self .previous_candidates = (
1237+ np .array (params ["previous_candidates" ]) if params ["previous_candidates" ] is not None else None
1238+ )
11711239
1240+ self ._deserialize_random_state (params ["random_states" ][- 1 ])
0 commit comments