@@ -69,18 +69,57 @@ def __init__(self, random_state: int | RandomState | None = None) -> None:
6969 self .random_state = RandomState ()
7070 self .i = 0
7171
72+ def _serialize_random_state (self ) -> dict | None :
73+ """Convert random state to JSON serializable format."""
74+ if self .random_state is not None :
75+ state = self .random_state .get_state ()
76+ return {
77+ 'bit_generator' : state [0 ],
78+ 'state' : state [1 ].tolist (), # Convert numpy array to list
79+ 'pos' : state [2 ],
80+ 'has_gauss' : state [3 ],
81+ 'cached_gaussian' : state [4 ]
82+ }
83+ return None
84+
85+ def _deserialize_random_state (self , state_dict : dict | None ) -> None :
86+ """Restore random state from JSON serializable format."""
87+ if state_dict is not None :
88+ if self .random_state is None :
89+ self .random_state = RandomState ()
90+ state = (
91+ state_dict ['bit_generator' ],
92+ np .array (state_dict ['state' ], dtype = np .uint32 ),
93+ state_dict ['pos' ],
94+ state_dict ['has_gauss' ],
95+ state_dict ['cached_gaussian' ]
96+ )
97+ self .random_state .set_state (state )
98+
7299 @abc .abstractmethod
73100 def base_acq (self , * args : Any , ** kwargs : Any ) -> NDArray [Float ]:
74101 """Provide access to the base acquisition function."""
75-
76- def _fit_gp (self , gp : GaussianProcessRegressor , target_space : TargetSpace ) -> None :
77- # Sklearn's GP throws a large number of warnings at times, but
78- # we don't really need to see them here.
79- with warnings .catch_warnings ():
80- warnings .simplefilter ("ignore" )
81- gp .fit (target_space .params , target_space .target )
82- if target_space .constraint is not None :
83- target_space .constraint .fit (target_space .params , target_space ._constraint_values )
102+
103+ @abc .abstractmethod
104+ def get_acquisition_params (self ) -> dict [str , Any ]:
105+ """Get the acquisition function parameters.
106+
107+ Returns
108+ -------
109+ dict
110+ Dictionary containing the acquisition function parameters.
111+ All values must be JSON serializable.
112+ """
113+
114+ @abc .abstractmethod
115+ def set_acquisition_params (self , params : dict [str , Any ]) -> None :
116+ """Set the acquisition function parameters.
117+
118+ Parameters
119+ ----------
120+ params : dict
121+ Dictionary containing the acquisition function parameters.
122+ """
84123
85124 def suggest (
86125 self ,
@@ -128,6 +167,15 @@ def suggest(
128167
129168 acq = self ._get_acq (gp = gp , constraint = target_space .constraint )
130169 return self ._acq_min (acq , target_space , n_random = n_random , n_l_bfgs_b = n_l_bfgs_b )
170+
171+ def _fit_gp (self , gp : GaussianProcessRegressor , target_space : TargetSpace ) -> None :
172+ # Sklearn's GP throws a large number of warnings at times, but
173+ # we don't really need to see them here.
174+ with warnings .catch_warnings ():
175+ warnings .simplefilter ("ignore" )
176+ gp .fit (target_space .params , target_space .target )
177+ if target_space .constraint is not None :
178+ target_space .constraint .fit (target_space .params , target_space ._constraint_values )
131179
132180 def _get_acq (
133181 self , gp : GaussianProcessRegressor , constraint : ConstraintModel | None = None
@@ -453,6 +501,20 @@ def decay_exploration(self) -> None:
453501 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
454502 ):
455503 self .kappa = self .kappa * self .exploration_decay
504+
505+ def get_acquisition_params (self ) -> dict :
506+ return {
507+ "kappa" : self .kappa ,
508+ "exploration_decay" : self .exploration_decay ,
509+ "exploration_decay_delay" : self .exploration_decay_delay ,
510+ "random_state" : self ._serialize_random_state ()
511+ }
512+
513+ def set_acquisition_params (self , params : dict ) -> None :
514+ self .kappa = params ["kappa" ]
515+ self .exploration_decay = params ["exploration_decay" ]
516+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
517+ self ._deserialize_random_state (params ["random_state" ])
456518
457519
458520class ProbabilityOfImprovement (AcquisitionFunction ):
@@ -586,6 +648,21 @@ def decay_exploration(self) -> None:
586648 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
587649 ):
588650 self .xi = self .xi * self .exploration_decay
651+
652+ def get_acquisition_params (self ) -> dict :
653+ """Get the acquisition function parameters."""
654+ return {
655+ "xi" : self .xi ,
656+ "exploration_decay" : self .exploration_decay ,
657+ "exploration_decay_delay" : self .exploration_decay_delay ,
658+ "random_state" : self ._serialize_random_state ()
659+ }
660+
661+ def set_acquisition_params (self , params : dict ) -> None :
662+ self .xi = params ["xi" ]
663+ self .exploration_decay = params ["exploration_decay" ]
664+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
665+ self ._deserialize_random_state (params ["random_state" ])
589666
590667
591668class ExpectedImprovement (AcquisitionFunction ):
@@ -727,6 +804,20 @@ def decay_exploration(self) -> None:
727804 self .exploration_decay_delay is None or self .exploration_decay_delay <= self .i
728805 ):
729806 self .xi = self .xi * self .exploration_decay
807+
808+ def get_acquisition_params (self ) -> dict :
809+ return {
810+ "xi" : self .xi ,
811+ "exploration_decay" : self .exploration_decay ,
812+ "exploration_decay_delay" : self .exploration_decay_delay ,
813+ "random_state" : self ._serialize_random_state ()
814+ }
815+
816+ def set_acquisition_params (self , params : dict ) -> None :
817+ self .xi = params ["xi" ]
818+ self .exploration_decay = params ["exploration_decay" ]
819+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
820+ self ._deserialize_random_state (params ["random_state" ])
730821
731822
732823class ConstantLiar (AcquisitionFunction ):
@@ -917,6 +1008,24 @@ def suggest(
9171008 self .dummies .append (x_max )
9181009
9191010 return x_max
1011+
1012+ def get_acquisition_params (self ) -> dict :
1013+ return {
1014+ "dummies" : [dummy .tolist () for dummy in self .dummies ],
1015+ "base_acquisition_params" : self .base_acquisition .get_acquisition_params (),
1016+ "strategy" : self .strategy ,
1017+ "atol" : self .atol ,
1018+ "rtol" : self .rtol ,
1019+ "random_state" : self ._serialize_random_state ()
1020+ }
1021+
1022+ def set_acquisition_params (self , params : dict ) -> None :
1023+ self .dummies = [np .array (dummy ) for dummy in params ["dummies" ]]
1024+ self .base_acquisition .set_acquisition_params (params ["base_acquisition_params" ])
1025+ self .strategy = params ["strategy" ]
1026+ self .atol = params ["atol" ]
1027+ self .rtol = params ["rtol" ]
1028+ self ._deserialize_random_state (params ["random_state" ])
9201029
9211030
9221031class GPHedge (AcquisitionFunction ):
@@ -1035,3 +1144,28 @@ def suggest(
10351144 self .previous_candidates = np .array (x_max )
10361145 idx = self ._sample_idx_from_softmax_gains ()
10371146 return x_max [idx ]
1147+
1148+ def get_acquisition_params (self ) -> dict :
1149+ return {
1150+ "base_acquisitions_params" : [acq .get_acquisition_params () for acq in self .base_acquisitions ],
1151+ "gains" : self .gains .tolist (),
1152+ "previous_candidates" : self .previous_candidates .tolist () if self .previous_candidates is not None else None ,
1153+ "random_states" : [acq ._serialize_random_state () for acq in self .base_acquisitions ] + [self ._serialize_random_state ()]
1154+ }
1155+
1156+ def set_acquisition_params (self , params : dict ) -> None :
1157+ for acq , acq_params , random_state in zip (
1158+ self .base_acquisitions ,
1159+ params ["base_acquisitions_params" ],
1160+ params ["random_states" ][:- 1 ]
1161+ ):
1162+ acq .set_acquisition_params (acq_params )
1163+ acq ._deserialize_random_state (random_state )
1164+
1165+ self .gains = np .array (params ["gains" ])
1166+ self .previous_candidates = (np .array (params ["previous_candidates" ])
1167+ if params ["previous_candidates" ] is not None
1168+ else None )
1169+
1170+ self ._deserialize_random_state (params ["random_states" ][- 1 ])
1171+
0 commit comments