Skip to content

Commit 0e859df

Browse files
committed
Refactor differential evolution import and optimize acquisition function
1 parent fb101a6 commit 0e859df

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

bayes_opt/acquisition.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import numpy as np
2929
from numpy.random import RandomState
30-
from scipy.optimize import differential_evolution, minimize
30+
from scipy.optimize._differentialevolution import DifferentialEvolutionSolver, minimize
3131
from scipy.special import softmax
3232
from scipy.stats import norm
3333
from sklearn.gaussian_process import GaussianProcessRegressor
@@ -297,12 +297,6 @@ def _smart_minimize(
297297
"""
298298
continuous_dimensions = space.continuous_dimensions
299299
continuous_bounds = space.bounds[continuous_dimensions]
300-
discrete_dimensions = ~continuous_dimensions
301-
302-
# if not continuous_dimensions.any():
303-
# min_acq = np.inf
304-
# x_min = np.array([np.nan] * space.bounds.shape[0])
305-
# return x_min, min_acq
306300

307301
min_acq: float | None = None
308302
x_try: NDArray[Float]
@@ -326,20 +320,17 @@ def _smart_minimize(
326320
ntrials = max(1, len(x_seeds) // 100)
327321
for _ in range(ntrials):
328322
xinit = space.random_sample(15 * len(space.bounds), random_state=self.random_state)
329-
res: OptimizeResult = differential_evolution(
330-
acq,
331-
bounds=space.bounds,
332-
init=xinit,
333-
integrality=discrete_dimensions,
334-
rng=self.random_state,
335-
)
323+
de = DifferentialEvolutionSolver(acq, bounds=space.bounds, init=xinit, rng=self.random_state)
324+
res: OptimizeResult = de.solve()
325+
336326
# See if success
337327
if not res.success:
338328
continue
339329

340330
# Store it if better than previous minimum(maximum).
341331
if min_acq is None or np.squeeze(res.fun) >= min_acq:
342-
x_try = res.x
332+
x_try_sc = de._unscale_parameters(res.x)
333+
x_try = space.kernel_transform(x_try_sc).flatten()
343334
x_min = x_try
344335
min_acq = np.squeeze(res.fun)
345336

0 commit comments

Comments
 (0)