Skip to content

Commit 8688445

Browse files
authored
Add .predict function to optimizer (#593)
* Add `.predict` function to optimizer * Allow all iterables of dicts
1 parent 4bca224 commit 8688445

File tree

2 files changed

+455
-47
lines changed

2 files changed

+455
-47
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import json
1010
from collections import deque
11+
from collections.abc import Iterable
1112
from os import PathLike
1213
from pathlib import Path
1314
from typing import TYPE_CHECKING, Any
@@ -175,6 +176,92 @@ def res(self) -> list[dict[str, Any]]:
175176
"""
176177
return self._space.res()
177178

179+
def predict(
180+
self,
181+
params: dict[str, Any] | Iterable[dict[str, Any]],
182+
return_std=False,
183+
return_cov=False,
184+
fit_gp=True,
185+
) -> float | NDArray[Float] | tuple[float | NDArray[Float], float | NDArray[Float]]:
186+
"""Predict the target function value at given parameters.
187+
188+
Parameters
189+
---------
190+
params: dict or iterable of dicts
191+
The parameters where the prediction is made.
192+
193+
return_std: bool, optional(default=False)
194+
If True, the standard deviation of the prediction is returned.
195+
196+
return_cov: bool, optional(default=False)
197+
If True, the covariance of the prediction is returned.
198+
199+
fit_gp: bool, optional(default=True)
200+
If True, the internal Gaussian Process model is fitted before
201+
making the prediction.
202+
203+
Returns
204+
-------
205+
mean: float or np.ndarray
206+
The predicted mean of the target function at the given parameters.
207+
When params is a dict, returns a scalar. When params is an iterable,
208+
returns a 1D array.
209+
210+
std_or_cov: float or np.ndarray (only if return_std or return_cov is True)
211+
The predicted standard deviation or covariance of the target function
212+
at the given parameters.
213+
"""
214+
# Validate param types
215+
if isinstance(params, dict):
216+
params_array = self._space.params_to_array(params).reshape(1, -1)
217+
single_param = True
218+
elif isinstance(params, Iterable) and not isinstance(params, str):
219+
# convert iterable of dicts to 2D array
220+
params_array = np.array([self._space.params_to_array(p) for p in params])
221+
single_param = False
222+
else:
223+
msg = f"params must be a dict or iterable of dicts, got {type(params).__name__}"
224+
raise TypeError(msg)
225+
226+
# Validate mutual exclusivity of return_std and return_cov
227+
if return_std and return_cov:
228+
msg = "return_std and return_cov cannot both be True"
229+
raise ValueError(msg)
230+
231+
if fit_gp:
232+
if len(self._space) == 0:
233+
msg = (
234+
"The Gaussian Process model cannot be fitted with zero observations. To use predict(), "
235+
"without fitting the GP, set fit_gp=False. The predictions will then be made using the "
236+
"GP prior."
237+
)
238+
raise RuntimeError(msg)
239+
self.acquisition_function._fit_gp(self._gp, self._space)
240+
241+
res = self._gp.predict(params_array, return_std=return_std, return_cov=return_cov)
242+
243+
if return_std or return_cov:
244+
mean, std_or_cov = res
245+
else:
246+
mean = res
247+
248+
# Shape semantics: dict input returns scalars, list input returns arrays
249+
# Ensure list input always returns arrays (convert scalar to 1D if needed)
250+
if not single_param and mean.ndim == 0:
251+
mean = np.atleast_1d(mean)
252+
# ruff complains when nesting conditionals, so this three-way split is necessary
253+
if not single_param and (return_std or return_cov) and std_or_cov.ndim == 0:
254+
std_or_cov = np.atleast_1d(std_or_cov)
255+
256+
if single_param and mean.ndim > 0:
257+
mean = mean[0]
258+
if single_param and return_std and std_or_cov.ndim > 0:
259+
std_or_cov = std_or_cov[0]
260+
261+
if return_std or return_cov:
262+
return mean, std_or_cov
263+
return mean
264+
178265
def register(
179266
self, params: ParamsType, target: float, constraint_value: float | NDArray[Float] | None = None
180267
) -> None:
@@ -303,8 +390,8 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None:
303390
probe based on the acquisition function. This means that the GP may
304391
not be fitted on all points registered to the target space when the
305392
method completes. If you intend to use the GP model after the
306-
optimization routine, make sure to fit it manually, e.g. by calling
307-
``optimizer._gp.fit(optimizer.space.params, optimizer.space.target)``.
393+
optimization routine, make sure to call predict() with fit_gp=True.
394+
308395
"""
309396
# Log optimization start
310397
self.logger.log_optimization_start(self._space.keys)

0 commit comments

Comments
 (0)