99
1010class BayesianOptimization (object ):
1111
12- def __init__ (self , f , pbounds , verbose = 1 ):
12+ def __init__ (self , f , pbounds , random_state = None , verbose = 1 ):
1313 """
1414 :param f:
1515 Function to be maximized.
@@ -25,6 +25,13 @@ def __init__(self, f, pbounds, verbose=1):
2525 # Store the original dictionary
2626 self .pbounds = pbounds
2727
28+ if random_state is None :
29+ self .random_state = np .random .RandomState ()
30+ elif isinstance (random_state , int ):
31+ self .random_state = np .random .RandomState (random_state )
32+ else :
33+ self .random_state = random_state
34+
2835 # Get the name of the parameters
2936 self .keys = list (pbounds .keys ())
3037
@@ -59,6 +66,7 @@ def __init__(self, f, pbounds, verbose=1):
5966 self .gp = GaussianProcessRegressor (
6067 kernel = Matern (nu = 2.5 ),
6168 n_restarts_optimizer = 25 ,
69+ random_state = self .random_state
6270 )
6371
6472 # Utility Function placeholder
@@ -87,7 +95,7 @@ def init(self, init_points):
8795 """
8896
8997 # Generate random points
90- l = [np . random .uniform (x [0 ], x [1 ], size = init_points )
98+ l = [self . random_state .uniform (x [0 ], x [1 ], size = init_points )
9199 for x in self .bounds ]
92100
93101 # Concatenate new random points to possible existing
@@ -274,7 +282,8 @@ def maximize(self,
274282 x_max = acq_max (ac = self .util .utility ,
275283 gp = self .gp ,
276284 y_max = y_max ,
277- bounds = self .bounds )
285+ bounds = self .bounds ,
286+ random_state = self .random_state )
278287
279288 # Print new header
280289 if self .verbose :
@@ -291,7 +300,7 @@ def maximize(self,
291300 pwarning = False
292301 if np .any ((self .X - x_max ).sum (axis = 1 ) == 0 ):
293302
294- x_max = np . random .uniform (self .bounds [:, 0 ],
303+ x_max = self . random_state .uniform (self .bounds [:, 0 ],
295304 self .bounds [:, 1 ],
296305 size = self .bounds .shape [0 ])
297306
@@ -313,7 +322,8 @@ def maximize(self,
313322 x_max = acq_max (ac = self .util .utility ,
314323 gp = self .gp ,
315324 y_max = y_max ,
316- bounds = self .bounds )
325+ bounds = self .bounds ,
326+ random_state = self .random_state )
317327
318328 # Print stuff
319329 if self .verbose :
0 commit comments