Skip to content

Commit b97c11e

Browse files
committed
Go back to wrap_kernel
1 parent 264b79e commit b97c11e

File tree

3 files changed

+41
-58
lines changed

3 files changed

+41
-58
lines changed

bayes_opt/bayesian_optimization.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from bayes_opt.domain_reduction import DomainTransformer
1818
from bayes_opt.event import DEFAULT_EVENTS, Events
1919
from bayes_opt.logger import _get_default_logger
20-
from bayes_opt.parameter import WrappedKernel
20+
from bayes_opt.parameter import wrap_kernel
2121
from bayes_opt.target_space import TargetSpace
2222
from bayes_opt.util import ensure_rng
2323

@@ -153,7 +153,7 @@ def __init__(
153153

154154
# Internal GP regressor
155155
self._gp = GaussianProcessRegressor(
156-
kernel=WrappedKernel(Matern(nu=2.5), transform=self._space.kernel_transform),
156+
kernel=wrap_kernel(Matern(nu=2.5), transform=self._space.kernel_transform),
157157
alpha=1e-6,
158158
normalize_y=True,
159159
n_restarts_optimizer=5,
@@ -330,7 +330,5 @@ def set_bounds(self, new_bounds: BoundsMapping) -> None:
330330
def set_gp_params(self, **params: Any) -> None:
331331
"""Set parameters of the internal Gaussian Process Regressor."""
332332
if "kernel" in params:
333-
params["kernel"] = WrappedKernel(
334-
base_kernel=params["kernel"], transform=self._space.kernel_transform
335-
)
333+
params["kernel"] = wrap_kernel(kernel=params["kernel"], transform=self._space.kernel_transform)
336334
self._gp.set_params(**params)

bayes_opt/constraint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sklearn.gaussian_process import GaussianProcessRegressor
1010
from sklearn.gaussian_process.kernels import Matern
1111

12-
from bayes_opt.parameter import WrappedKernel
12+
from bayes_opt.parameter import wrap_kernel
1313

1414
if TYPE_CHECKING:
1515
from collections.abc import Callable
@@ -71,7 +71,7 @@ def __init__(
7171

7272
self._model = [
7373
GaussianProcessRegressor(
74-
kernel=WrappedKernel(Matern(nu=2.5), transform) if transform is not None else Matern(nu=2.5),
74+
kernel=wrap_kernel(Matern(nu=2.5), transform) if transform is not None else Matern(nu=2.5),
7575
alpha=1e-6,
7676
normalize_y=True,
7777
n_restarts_optimizer=5,

bayes_opt/parameter.py

Lines changed: 36 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import abc
66
from collections.abc import Sequence
7+
from inspect import signature
78
from numbers import Number
89
from typing import TYPE_CHECKING, Any, Callable, Union
910

@@ -458,69 +459,53 @@ def dim(self) -> int:
458459
return len(self.categories)
459460

460461

461-
class WrappedKernel(kernels.Kernel):
462-
"""Wrap a kernel with a parameter transformation.
463-
464-
The transform function is applied to the input before passing it to the base kernel.
462+
def wrap_kernel(kernel: kernels.Kernel, transform: Callable[[Any], Any]) -> kernels.Kernel:
463+
"""Wrap a kernel to transform input data before passing it to the kernel.
465464
466465
Parameters
467466
----------
468-
base_kernel : kernels.Kernel
467+
kernel : kernels.Kernel
468+
The kernel to wrap.
469469
470-
transform : Callable[[Any], Any]
471-
"""
470+
transform : Callable
471+
The transformation function to apply to the input data.
472472
473-
def __init__(self, base_kernel: kernels.Kernel, transform: Callable[[Any], Any]) -> None:
474-
super().__init__()
475-
self.base_kernel = base_kernel
476-
self.transform = transform
473+
Returns
474+
-------
475+
kernels.Kernel
476+
The wrapped kernel.
477477
478-
def __call__(self, X: NDArray[Float], Y: NDArray[Float] = None, eval_gradient: bool = False) -> Any:
479-
"""Return the kernel k(X, Y) and optionally its gradient after applying the transform.
478+
Notes
479+
-----
480+
See https://arxiv.org/abs/1805.03463 for more information.
481+
"""
482+
kernel_type = type(kernel)
480483

481-
For details, see the documentation of the base kernel.
484+
class WrappedKernel(kernel_type):
485+
@copy_signature(getattr(kernel_type.__init__, "deprecated_original", kernel_type.__init__))
486+
def __init__(self, **kwargs: Any) -> None:
487+
super().__init__(**kwargs)
482488

483-
Parameters
484-
----------
485-
X : ndarray of shape (n_samples_X, n_features)
486-
Left argument of the returned kernel k(X, Y).
489+
def __call__(self, X: Any, Y: Any = None, eval_gradient: bool = False) -> Any:
490+
X = transform(X)
491+
Y = transform(Y) if Y is not None else None
492+
return super().__call__(X, Y, eval_gradient)
487493

488-
Y : ndarray of shape (n_samples_Y, n_features), default=None
489-
Right argument of the returned kernel k(X, Y). If None, k(X, X) is evaluated.
494+
def __reduce__(self) -> str | tuple[Any, ...]:
495+
return (wrap_kernel, (kernel, transform))
490496

491-
eval_gradient : bool, default=False
492-
Determines whether the gradient with respect to the kernel hyperparameter is calculated.
497+
return WrappedKernel(**kernel.get_params())
493498

494-
Returns
495-
-------
496-
K : ndarray of shape (n_samples_X, n_samples_Y)
497-
498-
K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims)
499-
"""
500-
X = self.transform(X)
501-
Y = self.transform(Y) if Y is not None else None
502-
return self.base_kernel(X, Y, eval_gradient)
503499

504-
def is_stationary(self):
505-
"""Return whether the kernel is stationary."""
506-
return self.base_kernel.is_stationary()
507-
508-
def diag(self, X: NDArray[Float]) -> NDArray[Float]:
509-
"""Return the diagonal of k(X, X).
510-
511-
This method allows for more efficient calculations than calling
512-
np.diag(self(X)).
500+
def copy_signature(source_fct: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
501+
"""Clones a signature from a source function to a target function.
513502
503+
via
504+
https://stackoverflow.com/a/58989918/
505+
"""
514506

515-
Parameters
516-
----------
517-
X : array-like of shape (n_samples,)
518-
Left argument of the returned kernel k(X, Y)
507+
def copy(target_fct: Callable[..., Any]) -> Callable[..., Any]:
508+
target_fct.__signature__ = signature(source_fct)
509+
return target_fct
519510

520-
Returns
521-
-------
522-
K_diag : ndarray of shape (n_samples_X,)
523-
Diagonal of kernel k(X, X)
524-
"""
525-
X = self.transform(X)
526-
return self.base_kernel.diag(X)
511+
return copy

0 commit comments

Comments
 (0)