Skip to content

Commit 1e51b28

Browse files
committed
Add .predict function to optimizer
1 parent c410d51 commit 1e51b28

File tree

2 files changed

+332
-47
lines changed

2 files changed

+332
-47
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,75 @@ def res(self) -> list[dict[str, Any]]:
175175
"""
176176
return self._space.res()
177177

178+
def predict(
179+
self, params: dict[str, Any] | list[dict[str, Any]], return_std=False, return_cov=False, fit_gp=True
180+
) -> tuple[float | NDArray[Float], float | NDArray[Float]]:
181+
"""Predict the target function value at given parameters.
182+
183+
Parameters
184+
---------
185+
params: dict or list
186+
The parameters where the prediction is made.
187+
188+
return_std: bool, optional(default=True)
189+
If True, the standard deviation of the prediction is returned.
190+
191+
return_cov: bool, optional(default=False)
192+
If True, the covariance of the prediction is returned.
193+
194+
fit_gp: bool, optional(default=True)
195+
If True, the internal Gaussian Process model is fitted before
196+
making the prediction.
197+
198+
Returns
199+
-------
200+
mean: float or np.ndarray
201+
The predicted mean of the target function at the given parameters.
202+
203+
std_or_cov: float or np.ndarray
204+
The predicted standard deviation or covariance of the target function
205+
at the given parameters.
206+
"""
207+
if isinstance(params, list):
208+
# convert list of dicts to 2D array
209+
params_array = np.array([self._space.params_to_array(p) for p in params])
210+
single_param = False
211+
elif isinstance(params, dict):
212+
params_array = self._space.params_to_array(params).reshape(1, -1)
213+
single_param = True
214+
215+
if fit_gp:
216+
if len(self._space) == 0:
217+
msg = (
218+
"The Gaussian Process model cannot be fitted with zero observations. To use predict(), "
219+
"without fitting the GP, set fit_gp=False. The predictions will then be made using the "
220+
"GP prior."
221+
)
222+
raise RuntimeError(msg)
223+
self.acquisition_function._fit_gp(self._gp, self._space)
224+
225+
res = self._gp.predict(params_array, return_std=return_std, return_cov=return_cov)
226+
227+
if return_std or return_cov:
228+
mean, std_or_cov = res
229+
else:
230+
mean = res
231+
232+
if not single_param and mean.ndim == 0:
233+
mean = np.atleast_1d(mean)
234+
# ruff complains when nesting conditionals, so this three-way split is necessary
235+
if not single_param and (return_std or return_cov) and std_or_cov.ndim == 0:
236+
std_or_cov = np.atleast_1d(std_or_cov)
237+
238+
if single_param and mean.ndim > 0:
239+
mean = mean[0]
240+
if single_param and (return_std or return_cov) and std_or_cov.ndim > 0:
241+
std_or_cov = std_or_cov[0]
242+
243+
if return_std or return_cov:
244+
return mean, std_or_cov
245+
return mean
246+
178247
def register(
179248
self, params: ParamsType, target: float, constraint_value: float | NDArray[Float] | None = None
180249
) -> None:
@@ -303,8 +372,8 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
303372
probe based on the acquisition function. This means that the GP may
304373
not be fitted on all points registered to the target space when the
305374
method completes. If you intend to use the GP model after the
306-
optimization routine, make sure to fit it manually, e.g. by calling
307-
``optimizer._gp.fit(optimizer.space.params, optimizer.space.target)``.
375+
optimization routine, make sure to call predict() with fit_gp=True.
376+
308377
"""
309378
# Log optimization start
310379
self.logger.log_optimization_start(self._space.keys)

0 commit comments

Comments
 (0)