@@ -69,6 +69,33 @@ def __init__(self, random_state: int | RandomState | None = None) -> None:
6969 self .random_state = RandomState ()
7070 self .i = 0
7171
72+ def _serialize_random_state (self ) -> dict | None :
73+ """Convert random state to JSON serializable format."""
74+ if self .random_state is not None :
75+ state = self .random_state .get_state ()
76+ return {
77+ "bit_generator" : state [0 ],
78+ "state" : state [1 ].tolist (), # Convert numpy array to list
79+ "pos" : state [2 ],
80+ "has_gauss" : state [3 ],
81+ "cached_gaussian" : state [4 ],
82+ }
83+ return None
84+
85+ def _deserialize_random_state (self , state_dict : dict | None ) -> None :
86+ """Restore random state from JSON serializable format."""
87+ if state_dict is not None :
88+ if self .random_state is None :
89+ self .random_state = RandomState ()
90+ state = (
91+ state_dict ["bit_generator" ],
92+ np .array (state_dict ["state" ], dtype = np .uint32 ),
93+ state_dict ["pos" ],
94+ state_dict ["has_gauss" ],
95+ state_dict ["cached_gaussian" ],
96+ )
97+ self .random_state .set_state (state )
98+
7299 @abc .abstractmethod
73100 def base_acq (self , * args : Any , ** kwargs : Any ) -> NDArray [Float ]:
74101 """Provide access to the base acquisition function."""
@@ -82,6 +109,34 @@ def _fit_gp(self, gp: GaussianProcessRegressor, target_space: TargetSpace) -> No
82109 if target_space .constraint is not None :
83110 target_space .constraint .fit (target_space .params , target_space ._constraint_values )
84111
112+ def get_acquisition_params (self ) -> dict [str , Any ]:
113+ """
114+ Get the parameters of the acquisition function.
115+
116+ Returns
117+ -------
118+ dict
119+ The parameters of the acquisition function.
120+ """
121+ error_msg = (
122+ "Custom AcquisitionFunction subclasses must implement their own get_acquisition_params method."
123+ )
124+ raise NotImplementedError (error_msg )
125+
126+ def set_acquisition_params (self , ** params ) -> None :
127+ """
128+ Set the parameters of the acquisition function.
129+
130+ Parameters
131+ ----------
132+ **params : dict
133+ The parameters of the acquisition function.
134+ """
135+ error_msg = (
136+ "Custom AcquisitionFunction subclasses must implement their own set_acquisition_params method."
137+ )
138+ raise NotImplementedError (error_msg )
139+
85140 def suggest (
86141 self ,
87142 gp : GaussianProcessRegressor ,
@@ -462,6 +517,34 @@ def decay_exploration(self) -> None:
462517 ):
463518 self .kappa = self .kappa * self .exploration_decay
464519
520+ def get_acquisition_params (self ) -> dict :
521+ """Get the current acquisition function parameters.
522+
523+ Returns
524+ -------
525+ dict
526+ Dictionary containing the current acquisition function parameters.
527+ """
528+ return {
529+ "kappa" : self .kappa ,
530+ "exploration_decay" : self .exploration_decay ,
531+ "exploration_decay_delay" : self .exploration_decay_delay ,
532+ "random_state" : self ._serialize_random_state (),
533+ }
534+
535+ def set_acquisition_params (self , params : dict ) -> None :
536+ """Set the acquisition function parameters.
537+
538+ Parameters
539+ ----------
540+ params : dict
541+ Dictionary containing the acquisition function parameters.
542+ """
543+ self .kappa = params ["kappa" ]
544+ self .exploration_decay = params ["exploration_decay" ]
545+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
546+ self ._deserialize_random_state (params ["random_state" ])
547+
465548
466549class ProbabilityOfImprovement (AcquisitionFunction ):
467550 r"""Probability of Improvement acqusition function.
@@ -595,6 +678,34 @@ def decay_exploration(self) -> None:
595678 ):
596679 self .xi = self .xi * self .exploration_decay
597680
681+ def get_acquisition_params (self ) -> dict :
682+ """Get the current acquisition function parameters.
683+
684+ Returns
685+ -------
686+ dict
687+ Dictionary containing the current acquisition function parameters.
688+ """
689+ return {
690+ "xi" : self .xi ,
691+ "exploration_decay" : self .exploration_decay ,
692+ "exploration_decay_delay" : self .exploration_decay_delay ,
693+ "random_state" : self ._serialize_random_state (),
694+ }
695+
696+ def set_acquisition_params (self , params : dict ) -> None :
697+ """Set the acquisition function parameters.
698+
699+ Parameters
700+ ----------
701+ params : dict
702+ Dictionary containing the acquisition function parameters.
703+ """
704+ self .xi = params ["xi" ]
705+ self .exploration_decay = params ["exploration_decay" ]
706+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
707+ self ._deserialize_random_state (params ["random_state" ])
708+
598709
599710class ExpectedImprovement (AcquisitionFunction ):
600711 r"""Expected Improvement acqusition function.
@@ -736,6 +847,34 @@ def decay_exploration(self) -> None:
736847 ):
737848 self .xi = self .xi * self .exploration_decay
738849
850+ def get_acquisition_params (self ) -> dict :
851+ """Get the current acquisition function parameters.
852+
853+ Returns
854+ -------
855+ dict
856+ Dictionary containing the current acquisition function parameters.
857+ """
858+ return {
859+ "xi" : self .xi ,
860+ "exploration_decay" : self .exploration_decay ,
861+ "exploration_decay_delay" : self .exploration_decay_delay ,
862+ "random_state" : self ._serialize_random_state (),
863+ }
864+
865+ def set_acquisition_params (self , params : dict ) -> None :
866+ """Set the acquisition function parameters.
867+
868+ Parameters
869+ ----------
870+ params : dict
871+ Dictionary containing the acquisition function parameters.
872+ """
873+ self .xi = params ["xi" ]
874+ self .exploration_decay = params ["exploration_decay" ]
875+ self .exploration_decay_delay = params ["exploration_decay_delay" ]
876+ self ._deserialize_random_state (params ["random_state" ])
877+
739878
740879class ConstantLiar (AcquisitionFunction ):
741880 """Constant Liar acquisition function.
@@ -926,6 +1065,38 @@ def suggest(
9261065
9271066 return x_max
9281067
1068+ def get_acquisition_params (self ) -> dict :
1069+ """Get the current acquisition function parameters.
1070+
1071+ Returns
1072+ -------
1073+ dict
1074+ Dictionary containing the current acquisition function parameters.
1075+ """
1076+ return {
1077+ "dummies" : [dummy .tolist () for dummy in self .dummies ],
1078+ "base_acquisition_params" : self .base_acquisition .get_acquisition_params (),
1079+ "strategy" : self .strategy ,
1080+ "atol" : self .atol ,
1081+ "rtol" : self .rtol ,
1082+ "random_state" : self ._serialize_random_state (),
1083+ }
1084+
1085+ def set_acquisition_params (self , params : dict ) -> None :
1086+ """Set the acquisition function parameters.
1087+
1088+ Parameters
1089+ ----------
1090+ params : dict
1091+ Dictionary containing the acquisition function parameters.
1092+ """
1093+ self .dummies = [np .array (dummy ) for dummy in params ["dummies" ]]
1094+ self .base_acquisition .set_acquisition_params (params ["base_acquisition_params" ])
1095+ self .strategy = params ["strategy" ]
1096+ self .atol = params ["atol" ]
1097+ self .rtol = params ["rtol" ]
1098+ self ._deserialize_random_state (params ["random_state" ])
1099+
9291100
9301101class GPHedge (AcquisitionFunction ):
9311102 """GPHedge acquisition function.
@@ -1043,3 +1214,38 @@ def suggest(
10431214 self .previous_candidates = np .array (x_max )
10441215 idx = self ._sample_idx_from_softmax_gains ()
10451216 return x_max [idx ]
1217+
1218+ def get_acquisition_params (self ) -> dict :
1219+ """Get the current acquisition function parameters.
1220+
1221+ Returns
1222+ -------
1223+ dict
1224+ Dictionary containing the current acquisition function parameters.
1225+ """
1226+ return {
1227+ "base_acquisitions_params" : [acq .get_acquisition_params () for acq in self .base_acquisitions ],
1228+ "gains" : self .gains .tolist (),
1229+ "previous_candidates" : self .previous_candidates .tolist ()
1230+ if self .previous_candidates is not None
1231+ else None ,
1232+ "gphedge_random_state" : self ._serialize_random_state (),
1233+ }
1234+
1235+ def set_acquisition_params (self , params : dict ) -> None :
1236+ """Set the acquisition function parameters.
1237+
1238+ Parameters
1239+ ----------
1240+ params : dict
1241+ Dictionary containing the acquisition function parameters.
1242+ """
1243+ for acq , acq_params in zip (self .base_acquisitions , params ["base_acquisitions_params" ]):
1244+ acq .set_acquisition_params (acq_params )
1245+
1246+ self .gains = np .array (params ["gains" ])
1247+ self .previous_candidates = (
1248+ np .array (params ["previous_candidates" ]) if params ["previous_candidates" ] is not None else None
1249+ )
1250+
1251+ self ._deserialize_random_state (params ["gphedge_random_state" ])
0 commit comments